From 85b18c591bca1daf167e363ddd45d6973dff75ab Mon Sep 17 00:00:00 2001 From: Cody Melton Date: Thu, 24 Aug 2023 16:50:30 -0600 Subject: [PATCH] add blas for SplineC2C applyRotation --- .../BsplineFactory/SplineC2C.cpp | 54 ++++++++++++------- .../BsplineFactory/SplineC2C.h | 2 +- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp b/src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp index 7f4b1d1bd5..dfa33f8da9 100644 --- a/src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp +++ b/src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp @@ -17,6 +17,7 @@ #include "spline2/MultiBsplineEval.hpp" #include "QMCWaveFunctions/BsplineFactory/contraction_helper.hpp" #include "CPU/math.hpp" +#include "CPU/BLAS.hpp" namespace qmcplusplus { @@ -57,7 +58,7 @@ void SplineC2C::storeParamsBeforeRotation() { const auto spline_ptr = SplineInst->getSplinePtr(); const auto coefs_tot_size = spline_ptr->coefs_size; - coef_copy_ = std::make_shared>(coefs_tot_size); + coef_copy_ = std::make_shared>(coefs_tot_size); std::copy_n(spline_ptr->coefs, coefs_tot_size, coef_copy_->begin()); } @@ -120,27 +121,40 @@ void SplineC2C::applyRotation(const ValueMatrix& rot_mat, bool use_stored_co std::copy_n(spl_coefs, coefs_tot_size, coef_copy_->begin()); } - for (int i = 0; i < basis_set_size; i++) - for (int j = 0; j < OrbitalSetSize; j++) - { - // cur_elem points to the real componend of the coefficient. - // Imag component is adjacent in memory. - const auto cur_elem = Nsplines * i + 2 * j; - ST newval_r{0.}; - ST newval_i{0.}; - for (auto k = 0; k < OrbitalSetSize; k++) + if constexpr (std::is_same_v) + { + //if ST is double, go ahead and use blas to make things faster + //Note that Nsplines needs to be divided by 2 since spl_coefs and coef_copy_ are stored as reals. + //Also casting them as ValueType so they are complex to do the correct gemm + BLAS::gemm('N', 'N', OrbitalSetSize, basis_set_size, OrbitalSetSize, ST(1.0), rot_mat.data(), OrbitalSetSize, + (ValueType*)(*coef_copy_).data(), Nsplines / 2, ST(0.0), (ValueType*)spl_coefs, Nsplines / 2); + } + else + { + // if ST is float, RealType is double and ValueType is std::complex for C2C + // Just use naive matrix multiplication in order to avoid losing precision on rotation matrix + for (int i = 0; i < basis_set_size; i++) + for (int j = 0; j < OrbitalSetSize; j++) { - const auto index = Nsplines * i + 2 * k; - ST zr = (*coef_copy_)[index]; - ST zi = (*coef_copy_)[index + 1]; - ST wr = rot_mat[k][j].real(); - ST wi = rot_mat[k][j].imag(); - newval_r += zr * wr - zi * wi; - newval_i += zr * wi + zi * wr; + // cur_elem points to the real componend of the coefficient. + // Imag component is adjacent in memory. + const auto cur_elem = Nsplines * i + 2 * j; + ST newval_r{0.}; + ST newval_i{0.}; + for (auto k = 0; k < OrbitalSetSize; k++) + { + const auto index = Nsplines * i + 2 * k; + ST zr = (*coef_copy_)[index]; + ST zi = (*coef_copy_)[index + 1]; + ST wr = rot_mat[k][j].real(); + ST wi = rot_mat[k][j].imag(); + newval_r += zr * wr - zi * wi; + newval_i += zr * wi + zi * wr; + } + spl_coefs[cur_elem] = newval_r; + spl_coefs[cur_elem + 1] = newval_i; } - spl_coefs[cur_elem] = newval_r; - spl_coefs[cur_elem + 1] = newval_i; - } + } } template diff --git a/src/QMCWaveFunctions/BsplineFactory/SplineC2C.h b/src/QMCWaveFunctions/BsplineFactory/SplineC2C.h index af082e0cea..9410e80cfb 100644 --- a/src/QMCWaveFunctions/BsplineFactory/SplineC2C.h +++ b/src/QMCWaveFunctions/BsplineFactory/SplineC2C.h @@ -64,7 +64,7 @@ class SplineC2C : public BsplineSet std::shared_ptr> SplineInst; ///Copy of original splines for orbital rotation - std::shared_ptr> coef_copy_; + std::shared_ptr> coef_copy_; vContainer_type mKK; VectorSoaContainer myKcart;