Skip to content

Commit

Permalink
Merge pull request #5040 from ye-luo/add-AccelBLAS_SYCL
Browse files Browse the repository at this point in the history
Add initial AccelBLAS_SYCL
  • Loading branch information
prckent authored Jun 12, 2024
2 parents e9785ff + 57e33d8 commit 007dd1a
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 13 deletions.
13 changes: 9 additions & 4 deletions config/build_alcf_sunspot_icpx.sh
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
#!/bin/bash
# This recipe is intended for ALCF Sunspot https://www.alcf.anl.gov/support-center/aurora-sunspot
# last revision: Mar 29th 2024
# last revision: June 11th 2024
#
# How to invoke this script?
# build_alcf_sunspot_icpx.sh # build all the variants assuming the current directory is the source directory.
# build_alcf_sunspot_icpx.sh <source_dir> # build all the variants with a given source directory <source_dir>
# build_alcf_sunspot_icpx.sh <source_dir> <install_dir> # build all the variants with a given source directory <source_dir> and install to <install_dir>

module load cmake hdf5/1.14.3 boost/1.83.0
module load oneapi/eng-compiler/2023.12.15.002
# Unload the generic oneAPI modules when they are currently loaded; presumably this
# keeps them from clashing with the versioned modules loaded afterwards (confirm).
for module_name in oneapi/release oneapi/eng-compiler; do
  if module is-loaded "$module_name"; then
    module unload "$module_name"
  fi
done

module load spack-pe-gcc cmake
module load oneapi/eng-compiler/2024.04.15.002
module load hdf5/1.14.3 boost/1.84.0
module list >& module_list.txt

echo "**********************************"
Expand All @@ -19,7 +24,7 @@ echo "**********************************"

TYPE=Release
Machine=sunspot
Compiler=icpx20231130
Compiler=icpx20240227

if [[ $# -eq 0 ]]; then
source_folder=`pwd`
Expand Down
3 changes: 3 additions & 0 deletions src/Platforms/AccelBLAS.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
#if defined(ENABLE_CUDA)
#include "CUDA/AccelBLAS_CUDA.hpp"
#endif
#if defined(ENABLE_SYCL)
#include "SYCL/AccelBLAS_SYCL.hpp"
#endif
#include "OMPTarget/AccelBLAS_OMPTarget.hpp"

#endif
151 changes: 151 additions & 0 deletions src/Platforms/SYCL/AccelBLAS_SYCL.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
//////////////////////////////////////////////////////////////////////////////////////
// This file is distributed under the University of Illinois/NCSA Open Source License.
// See LICENSE file in top directory for details.
//
// Copyright (c) 2024 QMCPACK developers.
//
// File developed by: Ye Luo, [email protected], Argonne National Laboratory
//////////////////////////////////////////////////////////////////////////////////////

#ifndef QMCPLUSPLUS_SYCL_ACCELBLAS_SYCL_H
#define QMCPLUSPLUS_SYCL_ACCELBLAS_SYCL_H

#include "AccelBLASHandle.hpp"
#include "SYCL/QueueSYCL.hpp"
#include "SYCL/syclBLAS.hpp"

namespace qmcplusplus
{
namespace compute
{
/** BLAS handle specialization for the SYCL platform.
 *
 * Holds a reference to the native sycl::queue exposed by a Queue<PlatformKind::SYCL>.
 * The BLAS routines in this file pass that queue to their oneMKL calls.
 */
template<>
class BLASHandle<PlatformKind::SYCL>
{
public:
  /// native SYCL queue; kept by reference only, this handle does not own it
  sycl::queue& queue_;

  /// bind the handle to the native queue returned by queue.getNative()
  BLASHandle(Queue<PlatformKind::SYCL>& queue) : queue_(queue.getNative()) {}
};

namespace BLAS
{
/** Matrix-matrix multiply forwarded to oneapi::mkl::blas::gemm on the handle's queue.
 *
 * @param handle  SYCL BLAS handle; its queue_ receives the oneMKL call
 * @param transa  transpose selector for A, translated by syclBLAS::convertTransEnum
 * @param transb  transpose selector for B, translated by syclBLAS::convertTransEnum
 * @param m,n,k   problem dimensions, forwarded unchanged
 * @param alpha   scalar factor, forwarded unchanged
 * @param A,lda   input matrix A and its leading dimension
 * @param B,ldb   input matrix B and its leading dimension
 * @param beta    scalar factor, forwarded unchanged
 * @param C,ldc   output matrix C and its leading dimension
 * @throws std::runtime_error if oneMKL throws oneapi::mkl::exception; the message is
 *         "oneapi::mkl exception: " followed by the original what() text.
 *
 * The value returned by the oneMKL call is not captured here.
 * NOTE(review): A, B and C presumably have to be accessible from the queue's device - confirm.
 */
template<typename T>
inline void gemm(BLASHandle<PlatformKind::SYCL>& handle,
                 const char transa,
                 const char transb,
                 int m,
                 int n,
                 int k,
                 const T& alpha,
                 const T* A,
                 int lda,
                 const T* B,
                 int ldb,
                 const T& beta,
                 T* C,
                 int ldc)
{
  try
  {
    oneapi::mkl::blas::gemm(handle.queue_, syclBLAS::convertTransEnum(transa), syclBLAS::convertTransEnum(transb), m, n,
                            k, alpha, A, lda, B, ldb, beta, C, ldc);
  }
  // catch by const reference: the exception object is only read
  catch (const oneapi::mkl::exception& e)
  {
    throw std::runtime_error(std::string("oneapi::mkl exception: ") + e.what());
  }
}

/** Placeholder for a batched matrix-vector multiply on the SYCL platform.
 *
 * NOTE(review): the body is empty, so a call returns immediately without reading any
 * argument and without writing to y. Going by the name and signature this is meant to
 * be a batched gemv - confirm that no caller relies on it before it is implemented.
 */
template<typename T>
inline void gemv_batched(BLASHandle<PlatformKind::SYCL>& handle,
const char trans,
const int m,
const int n,
const T* alpha,
const T* const A[],
const int lda,
const T* const x[],
const int incx,
const T* beta,
T* const y[],
const int incy,
const size_t batch_count)
{}

/** Placeholder for a batched rank-1 update on the SYCL platform.
 *
 * NOTE(review): the body is empty, so a call returns immediately without reading any
 * argument and without writing to A. Going by the name and signature this is meant to
 * be a batched ger - confirm that no caller relies on it before it is implemented.
 */
template<typename T>
inline void ger_batched(BLASHandle<PlatformKind::SYCL>& handle,
const int m,
const int n,
const T* alpha,
const T* const x[],
const int incx,
const T* const y[],
const int incy,
T* const A[],
const int lda,
const size_t batch_count)
{}

/** Placeholder for a batched vector copy on the SYCL platform.
 *
 * NOTE(review): the body is empty, so a call returns immediately without reading any
 * argument and without writing to out. Going by the name and signature this is meant to
 * be a batched copy - confirm that no caller relies on it before it is implemented.
 */
template<typename T>
inline void copy_batched(BLASHandle<PlatformKind::SYCL>& handle,
const int n,
const T* const in[],
const int incx,
T* const out[],
const int incy,
const size_t batch_count)
{}

/** Batched matrix-matrix multiply forwarded to oneapi::mkl::blas::gemm_batch.
 *
 * A single group is submitted: every entry of the batch uses the same transpose
 * selectors, m, n, k, alpha, beta and leading dimensions, while A, B and C are arrays
 * of batch_count matrix pointers.
 *
 * Two call forms exist, selected at compile time by GEMM_BATCH_SPAN:
 *  - defined: the sycl::span based interface. alpha and beta are staged in shared USM
 *    allocations; this path waits for the call to finish before releasing them.
 *  - otherwise: the pointer based interface taking addresses of the local arguments.
 *
 * @param handle       SYCL BLAS handle; its queue_ receives the oneMKL call
 * @param transa       transpose selector for A, translated by syclBLAS::convertTransEnum
 * @param transb       transpose selector for B, translated by syclBLAS::convertTransEnum
 * @param m,n,k        problem dimensions shared by the whole batch
 * @param alpha,beta   scalar factors shared by the whole batch
 * @param A,B,C        arrays of batch_count matrix pointers
 * @param lda,ldb,ldc  leading dimensions shared by the whole batch
 * @param batch_count  number of multiplications in the batch
 * @throws std::runtime_error if oneMKL throws oneapi::mkl::exception; the message is
 *         "oneapi::mkl exception: " followed by the original what() text.
 */
template<typename T>
inline void gemm_batched(BLASHandle<PlatformKind::SYCL>& handle,
                         const char transa,
                         const char transb,
                         syclBLAS::syclBLAS_int m,
                         syclBLAS::syclBLAS_int n,
                         syclBLAS::syclBLAS_int k,
                         const T& alpha,
                         const T* const A[],
                         syclBLAS::syclBLAS_int lda,
                         const T* const B[],
                         syclBLAS::syclBLAS_int ldb,
                         const T& beta,
                         T* const C[],
                         syclBLAS::syclBLAS_int ldc,
                         const size_t batch_count)
{
  auto trans_a = syclBLAS::convertTransEnum(transa);
  auto trans_b = syclBLAS::convertTransEnum(transb);
  try
  {
#if defined(GEMM_BATCH_SPAN)
    sycl::span alpha_span(sycl::malloc_shared<T>(1, handle.queue_), 1);
    alpha_span[0] = alpha;
    sycl::span beta_span(sycl::malloc_shared<T>(1, handle.queue_), 1);
    beta_span[0] = beta;

    const auto release_scalars = [&]() {
      sycl::free(alpha_span.data(), handle.queue_);
      sycl::free(beta_span.data(), handle.queue_);
    };

    try
    {
      // Wait on the returned event: the scalars live in USM that is freed right below,
      // and the submitted work may still be reading them until it has completed.
      oneapi::mkl::blas::gemm_batch(handle.queue_, sycl::span{&trans_a, 1}, sycl::span{&trans_b, 1}, sycl::span{&m, 1},
                                    sycl::span{&n, 1}, sycl::span{&k, 1}, alpha_span,
                                    sycl::span{const_cast<const T**>(A), batch_count}, sycl::span{&lda, 1},
                                    sycl::span{const_cast<const T**>(B), batch_count}, sycl::span{&ldb, 1},
                                    beta_span, sycl::span{const_cast<T**>(C), batch_count},
                                    sycl::span{&ldc, 1}, 1, sycl::span{const_cast<size_t*>(&batch_count), 1})
          .wait();
    }
    catch (...)
    {
      // do not leak the staged scalars when submission or completion fails
      release_scalars();
      throw;
    }
    release_scalars();
#else
    // the pointer interface expects the group size as syclBLAS_int, make the conversion explicit
    syclBLAS::syclBLAS_int bc = static_cast<syclBLAS::syclBLAS_int>(batch_count);
    oneapi::mkl::blas::gemm_batch(handle.queue_, &trans_a, &trans_b, &m, &n, &k, const_cast<const T*>(&alpha),
                                  const_cast<const T**>(A), &lda, const_cast<const T**>(B), &ldb,
                                  const_cast<const T*>(&beta), const_cast<T**>(C), &ldc, 1, &bc);
#endif
  }
  // catch by const reference: the exception object is only read
  catch (const oneapi::mkl::exception& e)
  {
    throw std::runtime_error(std::string("oneapi::mkl exception: ") + e.what());
  }
}

} // namespace BLAS
} // namespace compute
} // namespace qmcplusplus
#undef castNativeType
#endif
2 changes: 1 addition & 1 deletion src/Platforms/SYCL/QueueSYCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ template<>
class Queue<PlatformKind::SYCL>
{
public:
Queue() : queue_(createSYCLQueueOnDefaultDevice()) {}
Queue() : queue_(createSYCLInOrderQueueOnDefaultDevice()) {}

// dualspace container
template<class DSC>
Expand Down
8 changes: 7 additions & 1 deletion src/Platforms/SYCL/SYCLruntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@
namespace qmcplusplus
{
/// Return the queue reference handed out by SYCLDeviceManager::getDefaultDeviceDefaultQueue().
sycl::queue& getSYCLDefaultDeviceDefaultQueue()
{
  return SYCLDeviceManager::getDefaultDeviceDefaultQueue();
}
sycl::queue createSYCLQueueOnDefaultDevice()

/// Construct a new queue with the in_order property, using the context and device of
/// the default device's default queue.
sycl::queue createSYCLInOrderQueueOnDefaultDevice()
{
  sycl::queue& default_queue = getSYCLDefaultDeviceDefaultQueue();
  return sycl::queue(default_queue.get_context(), default_queue.get_device(), sycl::property::queue::in_order());
}

/// Construct a new queue without any queue property, using the context and device of
/// the default device's default queue.
sycl::queue createSYCLQueueOnDefaultDevice()
{
  sycl::queue& default_queue = getSYCLDefaultDeviceDefaultQueue();
  return sycl::queue(default_queue.get_context(), default_queue.get_device());
}

size_t getSYCLdeviceFreeMem()
{
auto device = getSYCLDefaultDeviceDefaultQueue().get_device();
Expand Down
4 changes: 3 additions & 1 deletion src/Platforms/SYCL/SYCLruntime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ namespace qmcplusplus
{
/// return a reference to the per-device default queue
sycl::queue& getSYCLDefaultDeviceDefaultQueue();
/// create a queue using the default device
/// create an in-order queue using the default device
sycl::queue createSYCLInOrderQueueOnDefaultDevice();
/// create an out-of-order queue using the default device
sycl::queue createSYCLQueueOnDefaultDevice();
/// query free memory on the default device
size_t getSYCLdeviceFreeMem();
Expand Down
5 changes: 0 additions & 5 deletions src/Platforms/SYCL/syclBLAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ namespace qmcplusplus
{
namespace syclBLAS
{
/** Translate a BLAS transpose character into the oneMKL transpose enum.
 *
 * 'T'/'t' select trans and 'C'/'c' select conjtrans. Every other character,
 * including 'N'/'n', selects nontrans; unrecognised characters keep this fallback
 * so existing callers are not broken.
 *
 * @param trans BLAS-style transpose character
 * @return the matching oneapi::mkl::transpose value
 */
inline oneapi::mkl::transpose convertTransEnum(char trans)
{
  switch (trans)
  {
  case 'T':
  case 't':
    return oneapi::mkl::transpose::trans;
  case 'C':
  case 'c':
    // previously fell through to nontrans, silently dropping the conjugate transpose
    return oneapi::mkl::transpose::conjtrans;
  default:
    return oneapi::mkl::transpose::nontrans;
  }
}

template<typename T>
sycl::event gemv(sycl::queue& handle,
const char trans,
Expand Down
15 changes: 15 additions & 0 deletions src/Platforms/SYCL/syclBLAS.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <complex>
#include <sycl/sycl.hpp>
#include <oneapi/mkl/blas.hpp>

namespace qmcplusplus
{
Expand All @@ -24,6 +25,20 @@ using syclBLAS_int = std::int64_t;
using syclBLAS_status = sycl::event;
using syclBLAS_handle = sycl::queue;

/** Translate a BLAS transpose character into the oneMKL transpose enum.
 *
 * @param trans one of 'N', 'T', 'C' or their lower-case forms
 * @return nontrans for 'N'/'n', trans for 'T'/'t', conjtrans for 'C'/'c'
 * @throws std::runtime_error for any other character; the message includes the offending value
 */
inline oneapi::mkl::transpose convertTransEnum(char trans)
{
  switch (trans)
  {
  case 'N':
  case 'n':
    return oneapi::mkl::transpose::nontrans;
  case 'T':
  case 't':
    return oneapi::mkl::transpose::trans;
  case 'C':
  case 'c':
    return oneapi::mkl::transpose::conjtrans;
  default:
    break;
  }
  throw std::runtime_error(
      "syclBLAS::convertTransEnum trans can only be 'N', 'T', 'C', 'n', 't', 'c'. Input value is " +
      std::string(1, trans));
}

template<typename T>
sycl::event gemv(sycl::queue& handle,
const char trans,
Expand Down
4 changes: 4 additions & 0 deletions src/Platforms/tests/test_AccelBLAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ TEST_CASE("AccelBLAS", "[BLAS]")
std::cout << "Testing gemm<PlatformKind::CUDA>" << std::endl;
test_gemm_cases<PlatformKind::CUDA>();
#endif
#if defined(ENABLE_SYCL)
std::cout << "Testing gemm<PlatformKind::SYCL>" << std::endl;
test_gemm_cases<PlatformKind::SYCL>();
#endif
#if defined(ENABLE_OFFLOAD)
std::cout << "Testing gemm<PlatformKind::OMPTARGET>" << std::endl;
test_gemm_cases<PlatformKind::OMPTARGET>();
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/Fermion/DelayedUpdateSYCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class DelayedUpdateSYCL

public:
/// default constructor
DelayedUpdateSYCL() : delay_count(0) { m_queue_ = createSYCLQueueOnDefaultDevice(); }
DelayedUpdateSYCL() : delay_count(0) { m_queue_ = createSYCLInOrderQueueOnDefaultDevice(); }

~DelayedUpdateSYCL() { syclSolver::freeBuffer(); }

Expand Down

0 comments on commit 007dd1a

Please sign in to comment.