Skip to content

Commit

Permalink
Merge pull request #4710 from camelto2/spl_rotation_with_blas
Browse files Browse the repository at this point in the history
SplineC2C/R2R rotation with BLAS
  • Loading branch information
ye-luo authored Aug 28, 2023
2 parents 283f243 + 91adaa9 commit 6c7b0d3
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 34 deletions.
55 changes: 35 additions & 20 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "spline2/MultiBsplineEval.hpp"
#include "QMCWaveFunctions/BsplineFactory/contraction_helper.hpp"
#include "CPU/math.hpp"
#include "CPU/BLAS.hpp"

namespace qmcplusplus
{
Expand Down Expand Up @@ -57,7 +58,7 @@ void SplineC2C<ST>::storeParamsBeforeRotation()
{
const auto spline_ptr = SplineInst->getSplinePtr();
const auto coefs_tot_size = spline_ptr->coefs_size;
coef_copy_ = std::make_shared<std::vector<RealType>>(coefs_tot_size);
coef_copy_ = std::make_shared<std::vector<ST>>(coefs_tot_size);

std::copy_n(spline_ptr->coefs, coefs_tot_size, coef_copy_->begin());
}
Expand Down Expand Up @@ -120,27 +121,41 @@ void SplineC2C<ST>::applyRotation(const ValueMatrix& rot_mat, bool use_stored_co
std::copy_n(spl_coefs, coefs_tot_size, coef_copy_->begin());
}

for (int i = 0; i < basis_set_size; i++)
for (int j = 0; j < OrbitalSetSize; j++)
{
// cur_elem points to the real componend of the coefficient.
// Imag component is adjacent in memory.
const auto cur_elem = Nsplines * i + 2 * j;
ST newval_r{0.};
ST newval_i{0.};
for (auto k = 0; k < OrbitalSetSize; k++)
if constexpr (std::is_same_v<ST, RealType>)
{
//if ST is double, go ahead and use blas to make things faster
//Note that Nsplines needs to be divided by 2 since spl_coefs and coef_copy_ are stored as reals.
//Also casting them as ValueType so they are complex to do the correct gemm
BLAS::gemm('N', 'N', OrbitalSetSize, basis_set_size, OrbitalSetSize, ValueType(1.0, 0.0), rot_mat.data(),
OrbitalSetSize, (ValueType*)coef_copy_->data(), Nsplines / 2, ValueType(0.0, 0.0),
(ValueType*)spl_coefs, Nsplines / 2);
}
else
{
// if ST is float, RealType is double and ValueType is std::complex<double> for C2C
// Just use naive matrix multiplication in order to avoid losing precision on rotation matrix
for (IndexType i = 0; i < basis_set_size; i++)
for (IndexType j = 0; j < OrbitalSetSize; j++)
{
const auto index = Nsplines * i + 2 * k;
ST zr = (*coef_copy_)[index];
ST zi = (*coef_copy_)[index + 1];
ST wr = rot_mat[k][j].real();
ST wi = rot_mat[k][j].imag();
newval_r += zr * wr - zi * wi;
newval_i += zr * wi + zi * wr;
// cur_elem points to the real componend of the coefficient.
// Imag component is adjacent in memory.
const auto cur_elem = Nsplines * i + 2 * j;
ST newval_r{0.};
ST newval_i{0.};
for (IndexType k = 0; k < OrbitalSetSize; k++)
{
const auto index = Nsplines * i + 2 * k;
ST zr = (*coef_copy_)[index];
ST zi = (*coef_copy_)[index + 1];
ST wr = rot_mat[k][j].real();
ST wi = rot_mat[k][j].imag();
newval_r += zr * wr - zi * wi;
newval_i += zr * wi + zi * wr;
}
spl_coefs[cur_elem] = newval_r;
spl_coefs[cur_elem + 1] = newval_i;
}
spl_coefs[cur_elem] = newval_r;
spl_coefs[cur_elem + 1] = newval_i;
}
}
}

template<typename ST>
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/SplineC2C.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class SplineC2C : public BsplineSet
std::shared_ptr<MultiBspline<ST>> SplineInst;

///Copy of original splines for orbital rotation
std::shared_ptr<std::vector<RealType>> coef_copy_;
std::shared_ptr<std::vector<ST>> coef_copy_;

vContainer_type mKK;
VectorSoaContainer<ST, 3> myKcart;
Expand Down
33 changes: 21 additions & 12 deletions src/QMCWaveFunctions/BsplineFactory/SplineR2R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "SplineR2R.h"
#include "spline2/MultiBsplineEval.hpp"
#include "QMCWaveFunctions/BsplineFactory/contraction_helper.hpp"
#include "Platforms/CPU/BLAS.hpp"

namespace qmcplusplus
{
Expand Down Expand Up @@ -56,7 +57,7 @@ void SplineR2R<ST>::storeParamsBeforeRotation()
{
const auto spline_ptr = SplineInst->getSplinePtr();
const auto coefs_tot_size = spline_ptr->coefs_size;
coef_copy_ = std::make_shared<std::vector<RealType>>(coefs_tot_size);
coef_copy_ = std::make_shared<std::vector<ST>>(coefs_tot_size);

std::copy_n(spline_ptr->coefs, coefs_tot_size, coef_copy_->begin());
}
Expand Down Expand Up @@ -120,20 +121,28 @@ void SplineR2R<ST>::applyRotation(const ValueMatrix& rot_mat, bool use_stored_co
std::copy_n(spl_coefs, coefs_tot_size, coef_copy_->begin());
}

// Apply rotation the dumb way b/c I can't get BLAS::gemm to work...
for (auto i = 0; i < BasisSetSize; i++)

if constexpr (std::is_same_v<ST, RealType>)
{
for (auto j = 0; j < OrbitalSetSize; j++)
{
const auto cur_elem = Nsplines * i + j;
auto newval{0.};
for (auto k = 0; k < OrbitalSetSize; k++)
//Here, ST should be equal to ValueType, which will be double for R2R. Using BLAS to make things faster
BLAS::gemm('N', 'N', OrbitalSetSize, BasisSetSize, OrbitalSetSize, ST(1.0), rot_mat.data(), OrbitalSetSize,
coef_copy_->data(), Nsplines, ST(0.0), spl_coefs, Nsplines);
}
else
{
//Here, ST is float but ValueType is double for R2R. Due to issues with type conversions, just doing naive matrix multiplication in this case to not lose precision on rot_mat
for (IndexType i = 0; i < BasisSetSize; i++)
for (IndexType j = 0; j < OrbitalSetSize; j++)
{
const auto index = i * Nsplines + k;
newval += (*coef_copy_)[index] * rot_mat[k][j];
const auto cur_elem = Nsplines * i + j;
FullPrecValueType newval{0.};
for (IndexType k = 0; k < OrbitalSetSize; k++)
{
const auto index = i * Nsplines + k;
newval += (*coef_copy_)[index] * rot_mat[k][j];
}
spl_coefs[cur_elem] = newval;
}
spl_coefs[cur_elem] = newval;
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/SplineR2R.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class SplineR2R : public BsplineSet
std::shared_ptr<MultiBspline<ST>> SplineInst;

///Copy of original splines for orbital rotation
std::shared_ptr<std::vector<RealType>> coef_copy_;
std::shared_ptr<std::vector<ST>> coef_copy_;

///thread private ratios for reduction when using nested threading, numVP x numThread
Matrix<TT> ratios_private;
Expand Down

0 comments on commit 6c7b0d3

Please sign in to comment.