Skip to content

Commit

Permalink
Merge branch 'spl_c2c_rotation_with_blas' into spl_rotation_with_blas
Browse files Browse the repository at this point in the history
  • Loading branch information
camelto2 committed Aug 24, 2023
2 parents c86e331 + 85b18c5 commit 79e3de9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
54 changes: 34 additions & 20 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "spline2/MultiBsplineEval.hpp"
#include "QMCWaveFunctions/BsplineFactory/contraction_helper.hpp"
#include "CPU/math.hpp"
#include "CPU/BLAS.hpp"

namespace qmcplusplus
{
Expand Down Expand Up @@ -57,7 +58,7 @@ void SplineC2C<ST>::storeParamsBeforeRotation()
{
const auto spline_ptr = SplineInst->getSplinePtr();
const auto coefs_tot_size = spline_ptr->coefs_size;
coef_copy_ = std::make_shared<std::vector<RealType>>(coefs_tot_size);
coef_copy_ = std::make_shared<std::vector<ST>>(coefs_tot_size);

std::copy_n(spline_ptr->coefs, coefs_tot_size, coef_copy_->begin());
}
Expand Down Expand Up @@ -120,27 +121,40 @@ void SplineC2C<ST>::applyRotation(const ValueMatrix& rot_mat, bool use_stored_co
std::copy_n(spl_coefs, coefs_tot_size, coef_copy_->begin());
}

for (int i = 0; i < basis_set_size; i++)
for (int j = 0; j < OrbitalSetSize; j++)
{
// cur_elem points to the real componend of the coefficient.
// Imag component is adjacent in memory.
const auto cur_elem = Nsplines * i + 2 * j;
ST newval_r{0.};
ST newval_i{0.};
for (auto k = 0; k < OrbitalSetSize; k++)
if constexpr (std::is_same_v<ST, RealType>)
{
//if ST is double, go ahead and use blas to make things faster
//Note that Nsplines needs to be divided by 2 since spl_coefs and coef_copy_ are stored as reals.
//Also casting them as ValueType so they are complex to do the correct gemm
BLAS::gemm('N', 'N', OrbitalSetSize, basis_set_size, OrbitalSetSize, ST(1.0), rot_mat.data(), OrbitalSetSize,
(ValueType*)(*coef_copy_).data(), Nsplines / 2, ST(0.0), (ValueType*)spl_coefs, Nsplines / 2);
}
else
{
// if ST is float, RealType is double and ValueType is std::complex<double> for C2C
// Just use naive matrix multiplication in order to avoid losing precision on rotation matrix
for (int i = 0; i < basis_set_size; i++)
for (int j = 0; j < OrbitalSetSize; j++)
{
const auto index = Nsplines * i + 2 * k;
ST zr = (*coef_copy_)[index];
ST zi = (*coef_copy_)[index + 1];
ST wr = rot_mat[k][j].real();
ST wi = rot_mat[k][j].imag();
newval_r += zr * wr - zi * wi;
newval_i += zr * wi + zi * wr;
// cur_elem points to the real componend of the coefficient.
// Imag component is adjacent in memory.
const auto cur_elem = Nsplines * i + 2 * j;
ST newval_r{0.};
ST newval_i{0.};
for (auto k = 0; k < OrbitalSetSize; k++)
{
const auto index = Nsplines * i + 2 * k;
ST zr = (*coef_copy_)[index];
ST zi = (*coef_copy_)[index + 1];
ST wr = rot_mat[k][j].real();
ST wi = rot_mat[k][j].imag();
newval_r += zr * wr - zi * wi;
newval_i += zr * wi + zi * wr;
}
spl_coefs[cur_elem] = newval_r;
spl_coefs[cur_elem + 1] = newval_i;
}
spl_coefs[cur_elem] = newval_r;
spl_coefs[cur_elem + 1] = newval_i;
}
}
}

template<typename ST>
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/SplineC2C.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class SplineC2C : public BsplineSet
std::shared_ptr<MultiBspline<ST>> SplineInst;

///Copy of original splines for orbital rotation
std::shared_ptr<std::vector<RealType>> coef_copy_;
std::shared_ptr<std::vector<ST>> coef_copy_;

vContainer_type mKK;
VectorSoaContainer<ST, 3> myKcart;
Expand Down

0 comments on commit 79e3de9

Please sign in to comment.