Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Add an interface to set the engine index of MT2203 engine #1760

Merged
merged 1 commit into from
Mar 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/rng_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,11 @@ class rng_generator_base {
virtual void set_direction_numbers(
const std::vector<std::uint32_t> &direction_numbers) = 0;

/// Set the engine index of host rng_generator. Only MT2203 engine
/// supports this method.
/// \param engine_idx The engine index.
virtual void set_engine_idx(std::uint32_t engine_idx) = 0;

protected:
/// Construct the host rng_generator.
/// \param queue The queue where the generator should be executed.
Expand All @@ -302,6 +307,7 @@ class rng_generator_base {
std::uint64_t _seed{0};
std::uint32_t _dimensions{1};
std::vector<std::uint32_t> _direction_numbers;
std::uint32_t _engine_idx{0};
};

/// The random number generator on host.
Expand Down Expand Up @@ -364,6 +370,25 @@ class rng_generator : public rng_generator_base {
#endif
}

/// Set the engine index of MT2203 host rng_generator.
/// \param engine_idx The user-defined engine index.
void set_engine_idx(std::uint32_t engine_idx) {
#ifndef __INTEL_MKL__
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) "
"Interfaces Project does not support this API.");
#else
if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::mt2203>) {
if (engine_idx == _engine_idx) {
return;
}
_engine_idx = engine_idx;
_engine = oneapi::mkl::rng::mt2203(*_queue, _seed, _engine_idx);
} else {
throw std::runtime_error("Only MT2203 engine supports this method.");
}
#endif
}

/// Generate unsigned int random number(s) with 'uniform_bits' distribution.
/// \param output The pointer of the first random number.
/// \param n The number of random numbers.
Expand Down
Loading