Skip to content

Commit

Permalink
[SYCL][COMPAT] Add bfloat16 support to several maths ops (#15572)
Browse files Browse the repository at this point in the history
Adds support for `sycl::ext::oneapi::bfloat16` to:
 - `relu`
 - `clamp`
 - `fmax_nan`
 - `fmin_nan`
 - `min`
 - `max`
 - `compare_mask`
 - `unordered_compare_mask`
  • Loading branch information
joeatodd authored Oct 17, 2024
1 parent 3796776 commit 1791115
Show file tree
Hide file tree
Showing 9 changed files with 498 additions and 164 deletions.
70 changes: 60 additions & 10 deletions sycl/doc/syclcompat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1729,7 +1729,51 @@ second operand, respectively. These three APIs return a single 32-bit value with
the accumulated result, which is unsigned if both operands are `uint32_t` and
signed otherwise.

Various maths functions are defined to operate on any floating point type.
`syclcompat::is_floating_point_v` extends the standard library's
`std::is_floating_point_v` to include `sycl::half` and, where available,
`sycl::ext::oneapi::bfloat16`. The current version of SYCLcompat also provides
a specialization of `std::common_type_t` for `sycl::ext::oneapi::bfloat16`,
though this will be moved to the `sycl_ext_oneapi_bfloat16` extension in
future.
```cpp
// Specializations of std::common_type for sycl::ext::oneapi::bfloat16.
// Every specialization below yields sycl::ext::oneapi::bfloat16 as `type`.
namespace std {
// Single-argument form.
template <> struct common_type<sycl::ext::oneapi::bfloat16> {
using type = sycl::ext::oneapi::bfloat16;
};
// Both arguments are bfloat16. This full specialization covers the case that
// would otherwise match both partial specializations below.
template <>
struct common_type<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16> {
using type = sycl::ext::oneapi::bfloat16;
};
// bfloat16 as the first argument, any type T as the second.
template <typename T> struct common_type<sycl::ext::oneapi::bfloat16, T> {
using type = sycl::ext::oneapi::bfloat16;
};
// Any type T as the first argument, bfloat16 as the second.
template <typename T> struct common_type<T, sycl::ext::oneapi::bfloat16> {
using type = sycl::ext::oneapi::bfloat16;
};
} // namespace std
```
```cpp
namespace syclcompat{
// Trait for extended floating point definition.
// The primary template inherits from std::is_floating_point<T>, so it gives
// the standard library's answer for any T that is not specialized below.
template <typename T>
struct is_floating_point : std::is_floating_point<T>{};
// sycl::half is always treated as a floating point type.
template <> struct is_floating_point<sycl::half> : std::true_type {};
// bfloat16 is only treated as a floating point type when
// SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS is defined.
#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
template <> struct is_floating_point<sycl::ext::oneapi::bfloat16> : std::true_type {};
#endif
// Convenience variable template giving is_floating_point<T>::value.
template <typename T>
inline constexpr bool is_floating_point_v = is_floating_point<T>::value;
inline unsigned int funnelshift_l(unsigned int low, unsigned int high,
unsigned int shift);
Expand All @@ -1756,11 +1800,9 @@ inline std::enable_if_t<ValueT::size() == 2, ValueT> isnan(const ValueT a);
// cbrt function wrapper.
template <typename ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
std::is_same_v<ValueT, sycl::half>,
ValueT>
cbrt(ValueT val) {
return sycl::cbrt(static_cast<ValueT>(val));
}
cbrt(ValueT val);
// For floating-point types, `float` or `double` arguments are acceptable.
// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
Expand Down Expand Up @@ -1798,6 +1840,10 @@ template <typename ValueT, typename ValueU>
inline sycl::vec<std::common_type_t<ValueT, ValueU>, 2>
fmax_nan(const sycl::vec<ValueT, 2> a, const sycl::vec<ValueU, 2> b);
template <typename ValueT, typename ValueU>
inline sycl::marray<std::common_type_t<ValueT, ValueU>, 2>
fmax_nan(const sycl::marray<ValueT, 2> a, const sycl::marray<ValueU, 2> b);
// Performs 2 elements comparison and returns the smaller one. If either of
// inputs is NaN, then return NaN.
template <typename ValueT, typename ValueU>
Expand All @@ -1807,6 +1853,10 @@ template <typename ValueT, typename ValueU>
inline sycl::vec<std::common_type_t<ValueT, ValueU>, 2>
fmin_nan(const sycl::vec<ValueT, 2> a, const sycl::vec<ValueU, 2> b);
template <typename ValueT, typename ValueU>
inline sycl::marray<std::common_type_t<ValueT, ValueU>, 2>
fmin_nan(const sycl::marray<ValueT, 2> a, const sycl::marray<ValueU, 2> b);
// pow overloads taking an `int` exponent `b`; both the float and the double
// version forward directly to sycl::pown(a, b).
inline float pow(const float a, const int b) { return sycl::pown(a, b); }
inline double pow(const double a, const int b) { return sycl::pown(a, b); }
Expand Down Expand Up @@ -1867,14 +1917,13 @@ unordered_compare_both(const ValueT a, const ValueT b,
const BinaryOperation binary_op);
template <typename ValueT, class BinaryOperation>
inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
const sycl::vec<ValueT, 2> b,
const BinaryOperation binary_op);
inline std::enable_if_t<ValueT::size() == 2, unsigned>
compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op);
template <typename ValueT, class BinaryOperation>
inline unsigned unordered_compare_mask(const sycl::vec<ValueT, 2> a,
const sycl::vec<ValueT, 2> b,
const BinaryOperation binary_op);
inline std::enable_if_t<ValueT::size() == 2, unsigned>
unordered_compare_mask(const ValueT a, const ValueT b,
const BinaryOperation binary_op);
template <typename S, typename T> inline T vectorized_max(T a, T b);
Expand Down Expand Up @@ -1928,6 +1977,7 @@ inline dot_product_acc_t<T1, T2> dp2a_hi(T1 a, T2 b,
template <typename T1, typename T2>
inline dot_product_acc_t<T1, T2> dp4a(T1 a, T2 b,
dot_product_acc_t<T1, T2> c);
} // namespace syclcompat
```
`vectorized_binary` computes the `BinaryOperation` for two operands,
Expand Down
Loading

0 comments on commit 1791115

Please sign in to comment.