Skip to content

Commit

Permalink
Merge pull request #4682 from ye-luo/remove-unit
Browse files Browse the repository at this point in the history
Remove xxx_unit cmake targets
  • Loading branch information
ye-luo authored Jul 21, 2023
2 parents 471754b + bf63508 commit b85264d
Show file tree
Hide file tree
Showing 22 changed files with 74 additions and 97 deletions.
5 changes: 0 additions & 5 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,6 @@ if(IS_GIT_PROJECT)
message("Git commit hash: ${GIT_CONFIG_COMMIT_HASH}")
endif()

# For unit tests, enable use for the fake RNG
function(USE_FAKE_RNG TARGET)
target_compile_definitions(${TARGET} PRIVATE "USE_FAKE_RNG")
endfunction()

add_subdirectory(io)
add_subdirectory(einspline)
add_subdirectory(Containers)
Expand Down
5 changes: 0 additions & 5 deletions src/Estimators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,11 @@ set(QMCEST_SRC
####################################
if(USE_OBJECT_TARGET)
add_library(qmcestimators OBJECT ${QMCEST_SRC})
add_library(qmcestimators_unit OBJECT ${QMCEST_SRC})
else()
add_library(qmcestimators ${QMCEST_SRC})
add_library(qmcestimators_unit ${QMCEST_SRC})
endif()
use_fake_rng(qmcestimators_unit)

target_include_directories(qmcestimators PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
target_include_directories(qmcestimators_unit PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
target_link_libraries(qmcestimators PUBLIC containers qmcham qmcparticle qmcutil)
target_link_libraries(qmcestimators_unit PUBLIC containers qmcham_unit qmcparticle qmcutil)

add_subdirectory(tests)
15 changes: 7 additions & 8 deletions src/Estimators/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ set(SRCS
test_MagnetizationDensity.cpp)

add_executable(${UTEST_EXE} ${SRCS})
use_fake_rng(${UTEST_EXE})
target_link_libraries(${UTEST_EXE} catch_main qmcutil qmcestimators_unit utilities_for_test sposets_for_testing)
target_link_libraries(${UTEST_EXE} catch_main qmcutil qmcestimators utilities_for_test sposets_for_testing)
if(USE_OBJECT_TARGET)
target_link_libraries(
${UTEST_EXE}
qmcutil
qmcestimators_unit
qmcham_unit
qmcestimators
qmcham
qmcwfs
qmcparticle
qmcwfs_omptarget
Expand All @@ -72,9 +71,9 @@ if(HAVE_MPI)
if(USE_OBJECT_TARGET)
target_link_libraries(
${UTEST_EXE}
qmcestimators_unit
qmcham_unit
qmcdriver_unit
qmcestimators
qmcham
qmcdriver
qmcwfs
qmcparticle
qmcwfs_omptarget
Expand All @@ -83,7 +82,7 @@ if(HAVE_MPI)
platform_omptarget_LA
utilities_for_test)
endif()
target_link_libraries(${UTEST_EXE} catch_main qmcestimators_unit)
target_link_libraries(${UTEST_EXE} catch_main qmcestimators)
# Right now the unified driver mpi tests are hard coded for 3 MPI ranks
add_unit_test(${UTEST_NAME} 3 1 $<TARGET_FILE:${UTEST_EXE}>)
endif()
2 changes: 1 addition & 1 deletion src/Estimators/tests/test_EstimatorManagerCrowd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ TEST_CASE("EstimatorManagerCrowd PerParticleHamiltonianLogger integration", "[es
emc.registerListeners(ham_list);

// Setup RNG
RandomGenerator rng;
FakeRandom<OHMMS_PRECISION_FULL> rng;

// Without this QMCHamiltonian::mw_evaluate segfaults
// Because the CoulombPBCAA hamiltonian component has PtclRhoK (StructFact) that is invalid.
Expand Down
2 changes: 1 addition & 1 deletion src/Estimators/tests/test_MomentumDistribution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ TEST_CASE("MomentumDistribution::accumulate", "[estimators]")
auto ref_wfns = convertUPtrToRefVector(wfns);

// Setup RNG
RandomGenerator rng;
FakeRandom<OHMMS_PRECISION_FULL> rng;

// Perform accumulate
md.accumulate(ref_walkers, ref_psets, ref_wfns, rng);
Expand Down
2 changes: 1 addition & 1 deletion src/Estimators/tests/test_PerParticleHamiltonianLogger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ TEST_CASE("PerParticleHamiltonianLogger_sum", "[estimators]")
for (auto& mwt : multi_walker_talkers)
mwt.reportVector();

RandomGenerator rng;
FakeRandom<OHMMS_PRECISION_FULL> rng;

int crowd_id = 0;
long walker_id = 0;
Expand Down
6 changes: 3 additions & 3 deletions src/Estimators/tests/test_SpinDensityNew.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void accumulateFromPsets(int ncrowds, SpinDensityNew& sdn, UPtrVector<OperatorEs
auto ref_psets = makeRefVector<ParticleSet>(psets);
auto ref_wfns = makeRefVector<TrialWaveFunction>(wfns);

RandomGenerator rng;
FakeRandom<OHMMS_PRECISION_FULL> rng;

crowd_sdn.accumulate(ref_walkers, ref_psets, ref_wfns, rng);
}
Expand Down Expand Up @@ -116,7 +116,7 @@ void randomUpdateAccumulate(testing::RandomForTest<QMCT::RealType>& rft, UPtrVec
auto ref_psets = makeRefVector<ParticleSet>(psets);
auto ref_wfns = makeRefVector<TrialWaveFunction>(wfns);

RandomGenerator rng;
FakeRandom<OHMMS_PRECISION_FULL> rng;

crowd_sdn.accumulate(ref_walkers, ref_psets, ref_wfns, rng);
}
Expand Down Expand Up @@ -220,7 +220,7 @@ TEST_CASE("SpinDensityNew::accumulate", "[estimators]")
auto ref_psets = makeRefVector<ParticleSet>(psets);
auto ref_wfns = makeRefVector<TrialWaveFunction>(wfns);

RandomGenerator rng;
FakeRandom<OHMMS_PRECISION_FULL> rng;

sdn.accumulate(ref_walkers, ref_psets, ref_wfns, rng);

Expand Down
1 change: 0 additions & 1 deletion src/Particle/ParticleBase/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ set(UTEST_EXE test_${SRC_DIR})
set(UTEST_NAME deterministic-unit_test_${SRC_DIR})

add_executable(${UTEST_EXE} test_particle_attrib.cpp test_random_seq.cpp test_attrib_ops.cpp)
use_fake_rng(${UTEST_EXE})
target_link_libraries(${UTEST_EXE} catch_main qmcparticle)
if(USE_OBJECT_TARGET)
target_link_libraries(${UTEST_EXE} qmcutil qmcparticle_omptarget)
Expand Down
7 changes: 0 additions & 7 deletions src/QMCDrivers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,24 +95,17 @@ endif(BUILD_LMYENGINE_INTERFACE)
####################################
if(USE_OBJECT_TARGET)
add_library(qmcdriver OBJECT ${QMCDRIVERS})
add_library(qmcdriver_unit OBJECT ${QMCDRIVERS})
else()
add_library(qmcdriver ${QMCDRIVERS})
add_library(qmcdriver_unit ${QMCDRIVERS})
endif()
use_fake_rng(qmcdriver_unit)

target_include_directories(qmcdriver PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
target_include_directories(qmcdriver_unit PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")

target_link_libraries(qmcdriver PUBLIC qmcham qmcestimators)
target_link_libraries(qmcdriver_unit PUBLIC qmcham_unit qmcestimators_unit)

target_link_libraries(qmcdriver PRIVATE platform_LA Boost::boost)
target_link_libraries(qmcdriver_unit PRIVATE platform_LA Boost::boost)
if(BUILD_LMYENGINE_INTERFACE)
target_link_libraries(qmcdriver PRIVATE formic_utils)
target_link_libraries(qmcdriver_unit PRIVATE formic_utils)
endif()

if(BUILD_UNIT_TESTS)
Expand Down
17 changes: 5 additions & 12 deletions src/QMCDrivers/DMC/DMC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "Concurrency/OpenMP.h"
#include "Utilities/Timer.h"
#include "Utilities/RunTimeManager.h"
#include "RandomNumberControl.h"
#include "Utilities/ProgressReportEngine.h"
#include "Utilities/qmc_common.h"
#include "Utilities/FairDivide.h"
Expand All @@ -45,9 +44,11 @@ DMC::DMC(const ProjectData& project_data,
MCWalkerConfiguration& w,
TrialWaveFunction& psi,
QMCHamiltonian& h,
UPtrVector<RandomBase<QMCTraits::FullPrecRealType>>& rngs,
Communicate* comm,
bool enable_profiling)
: QMCDriver(project_data, w, psi, h, comm, "DMC", enable_profiling),
rngs_(rngs),
KillNodeCrossing(0),
BranchInterval(-1),
L2("no"),
Expand Down Expand Up @@ -129,12 +130,8 @@ void DMC::resetUpdateEngines()
#if !defined(REMOVE_TRACEMANAGER)
traceClones[ip] = Traces->makeClone();
#endif
#ifdef USE_FAKE_RNG
Rng[ip] = std::make_unique<FakeRandom<QMCTraits::FullPrecRealType>>();
#else
Rng[ip] = RandomNumberControl::Children[ip]->makeClone();
Rng[ip] = rngs_[ip]->makeClone();
hClones[ip]->setRandomGenerator(Rng[ip].get());
#endif
if (W.isSpinor())
{
spinor = true;
Expand Down Expand Up @@ -299,10 +296,8 @@ bool DMC::run()
block++;
if (DumpConfig && block % Period4CheckPoint == 0)
{
#ifndef USE_FAKE_RNG
for (int ip = 0; ip < NumThreads; ip++)
RandomNumberControl::Children[ip] = Rng[ip]->makeClone();
#endif
rngs_[ip] = Rng[ip]->makeClone();
}
recordBlock(block);
dmc_loop.stop();
Expand All @@ -323,10 +318,8 @@ bool DMC::run()

} while (block < nBlocks);

#ifndef USE_FAKE_RNG
for (int ip = 0; ip < NumThreads; ip++)
RandomNumberControl::Children[ip] = Rng[ip]->makeClone();
#endif
rngs_[ip] = Rng[ip]->makeClone();
Estimators->stop();
for (int ip = 0; ip < NumThreads; ++ip)
Movers[ip]->stopRun2();
Expand Down
3 changes: 3 additions & 0 deletions src/QMCDrivers/DMC/DMC.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class DMC : public QMCDriver, public CloneManager
MCWalkerConfiguration& w,
TrialWaveFunction& psi,
QMCHamiltonian& h,
UPtrVector<RandomBase<QMCTraits::FullPrecRealType>>& rngs,
Communicate* comm,
bool enable_profiling);

Expand All @@ -44,6 +45,8 @@ class DMC : public QMCDriver, public CloneManager
QMCRunType getRunType() override { return QMCRunType::DMC; }

private:
//
UPtrVector<RandomBase<QMCTraits::FullPrecRealType>>& rngs_;
///Index to determine what to do when node crossing is detected
// does not appear to be used
IndexType KillNodeCrossing;
Expand Down
3 changes: 2 additions & 1 deletion src/QMCDrivers/DMC/DMCFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "DMCFactory.h"
#include "QMCDrivers/DMC/DMC.h"
#include "Concurrency/OpenMP.h"
#include "RandomNumberControl.h"

//#define PETA_DMC_TEST
namespace qmcplusplus
Expand All @@ -27,7 +28,7 @@ std::unique_ptr<QMCDriver> DMCFactory::create(const ProjectData& project_data,
Communicate* comm,
bool enable_profiling)
{
auto qmc = std::make_unique<DMC>(project_data, w, psi, h, comm, enable_profiling);
auto qmc = std::make_unique<DMC>(project_data, w, psi, h, RandomNumberControl::Children, comm, enable_profiling);
qmc->setUpdateMode(PbyPUpdate);
return qmc;
}
Expand Down
16 changes: 5 additions & 11 deletions src/QMCDrivers/VMC/VMC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "QMCDrivers/VMC/VMCUpdateAll.h"
#include "QMCDrivers/VMC/SOVMCUpdatePbyP.h"
#include "QMCDrivers/VMC/SOVMCUpdateAll.h"
#include "RandomNumberControl.h"
#include "Concurrency/OpenMP.h"
#include "Message/CommOperators.h"
#include "Utilities/RunTimeManager.h"
Expand All @@ -41,9 +40,10 @@ VMC::VMC(const ProjectData& project_data,
MCWalkerConfiguration& w,
TrialWaveFunction& psi,
QMCHamiltonian& h,
UPtrVector<RandomBase<QMCTraits::FullPrecRealType>>& rngs,
Communicate* comm,
bool enable_profiling)
: QMCDriver(project_data, w, psi, h, comm, "VMC", enable_profiling), UseDrift("yes")
: QMCDriver(project_data, w, psi, h, comm, "VMC", enable_profiling), UseDrift("yes"), rngs_(rngs)
{
RootName = "vmc";
qmc_driver_mode.set(QMC_UPDATE_MODE, 1);
Expand Down Expand Up @@ -132,10 +132,8 @@ bool VMC::run()
Traces->stopRun();
#endif
//copy back the random states
#ifndef USE_FAKE_RNG
for (int ip = 0; ip < NumThreads; ++ip)
RandomNumberControl::Children[ip] = Rng[ip]->makeClone();
#endif
rngs_[ip] = Rng[ip]->makeClone();
///write samples to a file
bool wrotesamples = DumpConfig;
if (DumpConfig)
Expand Down Expand Up @@ -182,11 +180,7 @@ void VMC::resetRun()
#if !defined(REMOVE_TRACEMANAGER)
traceClones[ip] = Traces->makeClone();
#endif
#ifdef USE_FAKE_RNG
Rng[ip] = std::make_unique<FakeRandom<double>>();
#else
Rng[ip] = RandomNumberControl::Children[ip]->makeClone();
#endif
Rng[ip] = rngs_[ip]->makeClone();
hClones[ip]->setRandomGenerator(Rng[ip].get());
if (W.isSpinor())
{
Expand Down Expand Up @@ -287,7 +281,7 @@ void VMC::resetRun()
qmc_common.memory_allocated += W.getActiveWalkers() * W[0]->DataSet.byteSize();
qmc_common.print_memory_change("VMC::resetRun", before);
}

for (int ip = 0; ip < NumThreads; ++ip)
wClones[ip]->clearEnsemble();
if (nSamplesPerThread)
Expand Down
4 changes: 4 additions & 0 deletions src/QMCDrivers/VMC/VMC.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,21 @@ class VMC : public QMCDriver, public CloneManager
MCWalkerConfiguration& w,
TrialWaveFunction& psi,
QMCHamiltonian& h,
UPtrVector<RandomBase<QMCTraits::FullPrecRealType>>& rngs,
Communicate* comm,
bool enable_profiling);
bool run() override;
bool put(xmlNodePtr cur) override;
QMCRunType getRunType() override { return QMCRunType::VMC; }

private:
int prevSteps;
int prevStepsBetweenSamples;

///option to enable/disable drift equation or RN for VMC
std::string UseDrift;
//
UPtrVector<RandomBase<QMCTraits::FullPrecRealType>>& rngs_;
///check the run-time environments
void resetRun();
///copy constructor
Expand Down
3 changes: 2 additions & 1 deletion src/QMCDrivers/VMC/VMCFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "QMCDrivers/VMC/VMC.h"
#include "QMCDrivers/QMCDriverInterface.h"
#include "QMCDrivers/CorrelatedSampling/CSVMC.h"
#include "RandomNumberControl.h"
#if defined(QMC_BUILD_COMPLETE)
//REMOVE Broken warping
//#if !defined(QMC_COMPLEX)
Expand All @@ -40,7 +41,7 @@ std::unique_ptr<QMCDriverInterface> VMCFactory::create(const ProjectData& projec
std::unique_ptr<QMCDriverInterface> qmc;
if (VMCMode == 0 || VMCMode == 1) //(0,0,0) (0,0,1)
{
qmc = std::make_unique<VMC>(project_data, w, psi, h, comm, enable_profiling);
qmc = std::make_unique<VMC>(project_data, w, psi, h, RandomNumberControl::Children, comm, enable_profiling);
}
else if (VMCMode == 2 || VMCMode == 3)
{
Expand Down
3 changes: 2 additions & 1 deletion src/QMCDrivers/WFOpt/QMCFixedSampleLinearOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "Particle/HDFWalkerIO.h"
#include "OhmmsData/AttributeSet.h"
#include "Message/CommOperators.h"
#include "RandomNumberControl.h"
#include "QMCDrivers/WFOpt/QMCCostFunctionBase.h"
#include "QMCDrivers/WFOpt/QMCCostFunction.h"
#include "QMCDrivers/VMC/VMC.h"
Expand Down Expand Up @@ -647,7 +648,7 @@ bool QMCFixedSampleLinearOptimize::processOptXML(xmlNodePtr opt_xml, const std::

// Destroy old object to stop timer to correctly order timer with object lifetime scope
vmcEngine.reset(nullptr);
vmcEngine = std::make_unique<VMC>(project_data_, W, Psi, H, myComm, false);
vmcEngine = std::make_unique<VMC>(project_data_, W, Psi, H, RandomNumberControl::Children, myComm, false);
vmcEngine->setUpdateMode(vmcMove[0] == 'p');


Expand Down
Loading

0 comments on commit b85264d

Please sign in to comment.