diff --git a/src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp b/src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp index 7f4b1d1bd5..8a3cd77c60 100644 --- a/src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp +++ b/src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp @@ -17,6 +17,7 @@ #include "spline2/MultiBsplineEval.hpp" #include "QMCWaveFunctions/BsplineFactory/contraction_helper.hpp" #include "CPU/math.hpp" +#include "CPU/BLAS.hpp" namespace qmcplusplus { @@ -57,7 +58,7 @@ void SplineC2C::storeParamsBeforeRotation() { const auto spline_ptr = SplineInst->getSplinePtr(); const auto coefs_tot_size = spline_ptr->coefs_size; - coef_copy_ = std::make_shared>(coefs_tot_size); + coef_copy_ = std::make_shared>(coefs_tot_size); std::copy_n(spline_ptr->coefs, coefs_tot_size, coef_copy_->begin()); } @@ -120,27 +121,41 @@ void SplineC2C::applyRotation(const ValueMatrix& rot_mat, bool use_stored_co std::copy_n(spl_coefs, coefs_tot_size, coef_copy_->begin()); } - for (int i = 0; i < basis_set_size; i++) - for (int j = 0; j < OrbitalSetSize; j++) - { - // cur_elem points to the real componend of the coefficient. - // Imag component is adjacent in memory. - const auto cur_elem = Nsplines * i + 2 * j; - ST newval_r{0.}; - ST newval_i{0.}; - for (auto k = 0; k < OrbitalSetSize; k++) + if constexpr (std::is_same_v) + { + //if ST is double, go ahead and use blas to make things faster + //Note that Nsplines needs to be divided by 2 since spl_coefs and coef_copy_ are stored as reals. + //Also casting them as ValueType so they are complex to do the correct gemm + BLAS::gemm('N', 'N', OrbitalSetSize, basis_set_size, OrbitalSetSize, ValueType(1.0, 0.0), rot_mat.data(), + OrbitalSetSize, (ValueType*)coef_copy_->data(), Nsplines / 2, ValueType(0.0, 0.0), + (ValueType*)spl_coefs, Nsplines / 2); + } + else + { + // if ST is float, RealType is double and ValueType is std::complex for C2C + // Just use naive matrix multiplication in order to avoid losing precision on rotation matrix + for (IndexType i = 0; i < basis_set_size; i++) + for (IndexType j = 0; j < OrbitalSetSize; j++) { - const auto index = Nsplines * i + 2 * k; - ST zr = (*coef_copy_)[index]; - ST zi = (*coef_copy_)[index + 1]; - ST wr = rot_mat[k][j].real(); - ST wi = rot_mat[k][j].imag(); - newval_r += zr * wr - zi * wi; - newval_i += zr * wi + zi * wr; + // cur_elem points to the real componend of the coefficient. + // Imag component is adjacent in memory. + const auto cur_elem = Nsplines * i + 2 * j; + ST newval_r{0.}; + ST newval_i{0.}; + for (IndexType k = 0; k < OrbitalSetSize; k++) + { + const auto index = Nsplines * i + 2 * k; + ST zr = (*coef_copy_)[index]; + ST zi = (*coef_copy_)[index + 1]; + ST wr = rot_mat[k][j].real(); + ST wi = rot_mat[k][j].imag(); + newval_r += zr * wr - zi * wi; + newval_i += zr * wi + zi * wr; + } + spl_coefs[cur_elem] = newval_r; + spl_coefs[cur_elem + 1] = newval_i; } - spl_coefs[cur_elem] = newval_r; - spl_coefs[cur_elem + 1] = newval_i; - } + } } template diff --git a/src/QMCWaveFunctions/BsplineFactory/SplineC2C.h b/src/QMCWaveFunctions/BsplineFactory/SplineC2C.h index af082e0cea..9410e80cfb 100644 --- a/src/QMCWaveFunctions/BsplineFactory/SplineC2C.h +++ b/src/QMCWaveFunctions/BsplineFactory/SplineC2C.h @@ -64,7 +64,7 @@ class SplineC2C : public BsplineSet std::shared_ptr> SplineInst; ///Copy of original splines for orbital rotation - std::shared_ptr> coef_copy_; + std::shared_ptr> coef_copy_; vContainer_type mKK; VectorSoaContainer myKcart; diff --git a/src/QMCWaveFunctions/BsplineFactory/SplineR2R.cpp b/src/QMCWaveFunctions/BsplineFactory/SplineR2R.cpp index ebe548dd25..5b0fa59ed3 100644 --- a/src/QMCWaveFunctions/BsplineFactory/SplineR2R.cpp +++ b/src/QMCWaveFunctions/BsplineFactory/SplineR2R.cpp @@ -17,6 +17,7 @@ #include "SplineR2R.h" #include "spline2/MultiBsplineEval.hpp" #include "QMCWaveFunctions/BsplineFactory/contraction_helper.hpp" +#include "Platforms/CPU/BLAS.hpp" namespace qmcplusplus { @@ -56,7 +57,7 @@ void SplineR2R::storeParamsBeforeRotation() { const auto spline_ptr = SplineInst->getSplinePtr(); const auto coefs_tot_size = spline_ptr->coefs_size; - coef_copy_ = std::make_shared>(coefs_tot_size); + coef_copy_ = std::make_shared>(coefs_tot_size); std::copy_n(spline_ptr->coefs, coefs_tot_size, coef_copy_->begin()); } @@ -120,20 +121,28 @@ void SplineR2R::applyRotation(const ValueMatrix& rot_mat, bool use_stored_co std::copy_n(spl_coefs, coefs_tot_size, coef_copy_->begin()); } - // Apply rotation the dumb way b/c I can't get BLAS::gemm to work... - for (auto i = 0; i < BasisSetSize; i++) + + if constexpr (std::is_same_v) { - for (auto j = 0; j < OrbitalSetSize; j++) - { - const auto cur_elem = Nsplines * i + j; - auto newval{0.}; - for (auto k = 0; k < OrbitalSetSize; k++) + //Here, ST should be equal to ValueType, which will be double for R2R. Using BLAS to make things faster + BLAS::gemm('N', 'N', OrbitalSetSize, BasisSetSize, OrbitalSetSize, ST(1.0), rot_mat.data(), OrbitalSetSize, + coef_copy_->data(), Nsplines, ST(0.0), spl_coefs, Nsplines); + } + else + { + //Here, ST is float but ValueType is double for R2R. Due to issues with type conversions, just doing naive matrix multiplication in this case to not lose precision on rot_mat + for (IndexType i = 0; i < BasisSetSize; i++) + for (IndexType j = 0; j < OrbitalSetSize; j++) { - const auto index = i * Nsplines + k; - newval += (*coef_copy_)[index] * rot_mat[k][j]; + const auto cur_elem = Nsplines * i + j; + FullPrecValueType newval{0.}; + for (IndexType k = 0; k < OrbitalSetSize; k++) + { + const auto index = i * Nsplines + k; + newval += (*coef_copy_)[index] * rot_mat[k][j]; + } + spl_coefs[cur_elem] = newval; } - spl_coefs[cur_elem] = newval; - } } } diff --git a/src/QMCWaveFunctions/BsplineFactory/SplineR2R.h b/src/QMCWaveFunctions/BsplineFactory/SplineR2R.h index a3ac0f919d..3de6fc33fc 100644 --- a/src/QMCWaveFunctions/BsplineFactory/SplineR2R.h +++ b/src/QMCWaveFunctions/BsplineFactory/SplineR2R.h @@ -59,7 +59,7 @@ class SplineR2R : public BsplineSet std::shared_ptr> SplineInst; ///Copy of original splines for orbital rotation - std::shared_ptr> coef_copy_; + std::shared_ptr> coef_copy_; ///thread private ratios for reduction when using nested threading, numVP x numThread Matrix ratios_private;