Skip to content

Commit

Permalink
Port FakeSPOSet to SPOSetT
Browse files Browse the repository at this point in the history
  • Loading branch information
ye-luo committed Nov 9, 2023
1 parent bc4a64f commit 5036d7e
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 29 deletions.
37 changes: 26 additions & 11 deletions src/QMCWaveFunctions/tests/FakeSPO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
namespace qmcplusplus
{

FakeSPO::FakeSPO() : SPOSet("one_FakeSPO")
template<typename VALUE>
FakeSPO<VALUE>::FakeSPO() : SPOSet("one_FakeSPO")
{
a.resize(3, 3);

Expand Down Expand Up @@ -78,11 +79,20 @@ FakeSPO::FakeSPO() : SPOSet("one_FakeSPO")
gv[3] = TinyVector<ValueType, DIM>(0.4, 0.3, 0.1);
}

std::unique_ptr<SPOSet> FakeSPO::makeClone() const { return std::make_unique<FakeSPO>(*this); }
template<typename VALUE>
std::unique_ptr<SPOSetT<VALUE>> FakeSPO<VALUE>::makeClone() const
{
return std::make_unique<FakeSPO>(*this);
}

void FakeSPO::setOrbitalSetSize(int norbs) { OrbitalSetSize = norbs; }
template<typename VALUE>
void FakeSPO<VALUE>::setOrbitalSetSize(int norbs)
{
OrbitalSetSize = norbs;
}

void FakeSPO::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)
template<typename VALUE>
void FakeSPO<VALUE>::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)
{
if (iat < 0)
for (int i = 0; i < psi.size(); i++)
Expand All @@ -95,7 +105,8 @@ void FakeSPO::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)
psi[i] = a2(iat, i);
}

void FakeSPO::evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi)
template<typename VALUE>
void FakeSPO<VALUE>::evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi)
{
if (OrbitalSetSize == 3)
{
Expand All @@ -115,12 +126,13 @@ void FakeSPO::evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradV
}
}

void FakeSPO::evaluate_notranspose(const ParticleSet& P,
int first,
int last,
ValueMatrix& logdet,
GradMatrix& dlogdet,
ValueMatrix& d2logdet)
template<typename VALUE>
void FakeSPO<VALUE>::evaluate_notranspose(const ParticleSet& P,
int first,
int last,
ValueMatrix& logdet,
GradMatrix& dlogdet,
ValueMatrix& d2logdet)
{
if (OrbitalSetSize == 3)
{
Expand All @@ -142,4 +154,7 @@ void FakeSPO::evaluate_notranspose(const ParticleSet& P,
}
}

template class FakeSPO<QMCTraits::QTBase::RealType>;
template class FakeSPO<QMCTraits::QTBase::ComplexType>;

} // namespace qmcplusplus
23 changes: 21 additions & 2 deletions src/QMCWaveFunctions/tests/FakeSPO.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,28 @@
namespace qmcplusplus
{

class FakeSPO : public SPOSet
template<typename VALUE>
class FakeSPO : public SPOSetT<VALUE>
{
public:
enum
{
DIM = OHMMS_DIM,
};
using SPOSet = SPOSetT<VALUE>;
using ValueType = typename SPOSet::ValueType;
using GradType = typename SPOSet::GradType;
using ValueVector = typename SPOSet::ValueVector;
using GradVector = typename SPOSet::GradVector;
using ValueMatrix = typename SPOSet::ValueMatrix;
using GradMatrix = typename SPOSet::GradMatrix;

Matrix<ValueType> a;
Matrix<ValueType> a2;
Vector<ValueType> v;
Matrix<ValueType> v2;

SPOSet::GradVector gv;
GradVector gv;

FakeSPO();
~FakeSPO() override {}
Expand All @@ -47,7 +60,13 @@ class FakeSPO : public SPOSet
ValueMatrix& logdet,
GradMatrix& dlogdet,
ValueMatrix& d2logdet) override;

private:
using SPOSet::OrbitalSetSize;
};

extern template class FakeSPO<QMCTraits::QTBase::RealType>;
extern template class FakeSPO<QMCTraits::QTBase::ComplexType>;

} // namespace qmcplusplus
#endif
12 changes: 6 additions & 6 deletions src/QMCWaveFunctions/tests/test_DiracDeterminant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ void check_matrix(Matrix<T1>& a, Matrix<T2>& b)
template<typename DET>
void test_DiracDeterminant_first(const DetMatInvertor inverter_kind)
{
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 3;
spo_init->setOrbitalSetSize(norb);
DET ddb(std::move(spo_init), 0, norb, 1, inverter_kind);
auto spo = dynamic_cast<FakeSPO*>(ddb.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddb.getPhi());

// occurs in call to registerData
ddb.dpsiV.resize(norb);
Expand Down Expand Up @@ -159,11 +159,11 @@ TEST_CASE("DiracDeterminant_first", "[wavefunction][fermion]")
template<typename DET>
void test_DiracDeterminant_second(const DetMatInvertor inverter_kind)
{
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 4;
spo_init->setOrbitalSetSize(norb);
DET ddb(std::move(spo_init), 0, norb, 1, inverter_kind);
auto spo = dynamic_cast<FakeSPO*>(ddb.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddb.getPhi());

// occurs in call to registerData
ddb.dpsiV.resize(norb);
Expand Down Expand Up @@ -300,12 +300,12 @@ TEST_CASE("DiracDeterminant_second", "[wavefunction][fermion]")
template<typename DET>
void test_DiracDeterminant_delayed_update(const DetMatInvertor inverter_kind)
{
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 4;
spo_init->setOrbitalSetSize(norb);
// maximum delay 2
DET ddc(std::move(spo_init), 0, norb, 2, inverter_kind);
auto spo = dynamic_cast<FakeSPO*>(ddc.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddc.getPhi());

// occurs in call to registerData
ddc.dpsiV.resize(norb);
Expand Down
12 changes: 6 additions & 6 deletions src/QMCWaveFunctions/tests/test_DiracDeterminantBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ template<class DET_ENGINE>
void test_DiracDeterminantBatched_first()
{
using DetType = DiracDeterminantBatched<DET_ENGINE>;
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 3;
spo_init->setOrbitalSetSize(norb);
DetType ddb(std::move(spo_init), 0, norb);
auto spo = dynamic_cast<FakeSPO*>(ddb.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddb.getPhi());

// occurs in call to registerData
ddb.dpsiV.resize(norb);
Expand Down Expand Up @@ -141,11 +141,11 @@ template<class DET_ENGINE>
void test_DiracDeterminantBatched_second()
{
using DetType = DiracDeterminantBatched<DET_ENGINE>;
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 4;
spo_init->setOrbitalSetSize(norb);
DetType ddb(std::move(spo_init), 0, norb);
auto spo = dynamic_cast<FakeSPO*>(ddb.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddb.getPhi());

// occurs in call to registerData
ddb.dpsiV.resize(norb);
Expand Down Expand Up @@ -277,11 +277,11 @@ template<class DET_ENGINE>
void test_DiracDeterminantBatched_delayed_update(int delay_rank, DetMatInvertor matrix_inverter_kind)
{
using DetType = DiracDeterminantBatched<DET_ENGINE>;
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 4;
spo_init->setOrbitalSetSize(norb);
DetType ddc(std::move(spo_init), 0, norb, delay_rank, matrix_inverter_kind);
auto spo = dynamic_cast<FakeSPO*>(ddc.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddc.getPhi());

// occurs in call to registerData
ddc.dpsiV.resize(norb);
Expand Down
8 changes: 4 additions & 4 deletions src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ std::vector<std::vector<QMCTraits::RealType>>& getHistoryParams(RotatedSPOs& rot
// Test using global rotation
TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]")
{
auto fake_spo = std::make_unique<FakeSPO>();
auto fake_spo = std::make_unique<FakeSPO<QMCTraits::ValueType>>();
fake_spo->setOrbitalSetSize(4);
RotatedSPOs rot("fake_rot", std::move(fake_spo));
int nel = 2;
Expand All @@ -674,7 +674,7 @@ TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]")
rot.writeVariationalParameters(hout);
}

auto fake_spo2 = std::make_unique<FakeSPO>();
auto fake_spo2 = std::make_unique<FakeSPO<QMCTraits::ValueType>>();
fake_spo2->setOrbitalSetSize(4);

RotatedSPOs rot2("fake_rot", std::move(fake_spo2));
Expand Down Expand Up @@ -705,7 +705,7 @@ TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]")
// Test using history list.
TEST_CASE("RotatedSPOs read and write parameters history", "[wavefunction]")
{
auto fake_spo = std::make_unique<FakeSPO>();
auto fake_spo = std::make_unique<FakeSPO<QMCTraits::ValueType>>();
fake_spo->setOrbitalSetSize(4);
RotatedSPOs rot("fake_rot", std::move(fake_spo));
rot.set_use_global_rotation(false);
Expand All @@ -727,7 +727,7 @@ TEST_CASE("RotatedSPOs read and write parameters history", "[wavefunction]")
rot.writeVariationalParameters(hout);
}

auto fake_spo2 = std::make_unique<FakeSPO>();
auto fake_spo2 = std::make_unique<FakeSPO<QMCTraits::ValueType>>();
fake_spo2->setOrbitalSetSize(4);

RotatedSPOs rot2("fake_rot", std::move(fake_spo2));
Expand Down

0 comments on commit 5036d7e

Please sign in to comment.