Skip to content

Commit 8e69702

Browse files
authored
[NFC][SYCL] Use bit_cast for bfloat16 casts. (#17256)
bfloat16.hpp defines two helper functions to cast between 16-bit integer and bfloat16 types. These functions duplicate the logic of the standard sycl::bit_cast function.
1 parent 7bdb758 commit 8e69702

File tree

2 files changed

+57
-91
lines changed

2 files changed

+57
-91
lines changed

sycl/include/sycl/ext/oneapi/bfloat16.hpp

+12-34
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <sycl/aliases.hpp> // for half
12+
#include <sycl/bit_cast.hpp> // for bit_cast
1213
#include <sycl/detail/defines_elementary.hpp> // for __DPCPP_SYCL_EXTERNAL
1314
#include <sycl/half_type.hpp> // for half
1415

@@ -51,11 +52,6 @@ class bfloat16;
5152

5253
namespace detail {
5354
using Bfloat16StorageT = uint16_t;
54-
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value);
55-
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value);
56-
// Class to convert different data types to Bfloat16
57-
// with different rounding modes.
58-
class ConvertToBfloat16;
5955

6056
template <int N> void BF16VecToFloatVec(const bfloat16 src[N], float dst[N]) {
6157
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
@@ -84,12 +80,6 @@ class bfloat16 {
8480
protected:
8581
detail::Bfloat16StorageT value;
8682

87-
friend inline detail::Bfloat16StorageT
88-
detail::bfloat16ToBits(const bfloat16 &Value);
89-
friend inline bfloat16
90-
detail::bitsToBfloat16(const detail::Bfloat16StorageT Value);
91-
friend class detail::ConvertToBfloat16;
92-
9383
public:
9484
bfloat16() = default;
9585
~bfloat16() = default;
@@ -187,7 +177,7 @@ class bfloat16 {
187177
(__SYCL_CUDA_ARCH__ >= 800)
188178
detail::Bfloat16StorageT res;
189179
asm("neg.bf16 %0, %1;" : "=h"(res) : "h"(lhs.value));
190-
return detail::bitsToBfloat16(res);
180+
return bit_cast<bfloat16>(res);
191181
#elif defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
192182
return bfloat16{-__devicelib_ConvertBF16ToFINTEL(lhs.value)};
193183
#else
@@ -294,19 +284,6 @@ template <int N> void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) {
294284
#endif
295285
}
296286

297-
// Helper function for getting the internal representation of a bfloat16.
298-
inline Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value) {
299-
return Value.value;
300-
}
301-
302-
// Helper function for creating a float16 from a value with the same type as the
303-
// internal representation.
304-
inline bfloat16 bitsToBfloat16(const Bfloat16StorageT Value) {
305-
bfloat16 res;
306-
res.value = Value;
307-
return res;
308-
}
309-
310287
// Class to convert different data types to Bfloat16
311288
// with different rounding modes.
312289
class ConvertToBfloat16 {
@@ -348,15 +325,15 @@ class ConvertToBfloat16 {
348325
// +/-infinity and NAN
349326
if (bf16_exp == 0xFF) {
350327
if (!f_mant)
351-
return bitsToBfloat16(bf16_sign ? 0xFF80 : 0x7F80);
328+
return bit_cast<bfloat16, uint16_t>(bf16_sign ? 0xFF80 : 0x7F80);
352329
else
353-
return bitsToBfloat16((bf16_sign << 15) | (bf16_exp << 7) |
354-
bf16_mant);
330+
return bit_cast<bfloat16, uint16_t>((bf16_sign << 15) |
331+
(bf16_exp << 7) | bf16_mant);
355332
}
356333

357334
// +/-0
358335
if (!bf16_exp && !f_mant) {
359-
return bitsToBfloat16(bf16_sign ? 0x8000 : 0x0);
336+
return bit_cast<bfloat16, uint16_t>(bf16_sign ? 0x8000 : 0x0);
360337
}
361338

362339
uint16_t mant_discard = static_cast<uint16_t>(f_mant & 0xFFFF);
@@ -385,7 +362,8 @@ class ConvertToBfloat16 {
385362
bf16_exp++;
386363
}
387364

388-
return bitsToBfloat16((bf16_sign << 15) | (bf16_exp << 7) | bf16_mant);
365+
return bit_cast<bfloat16, uint16_t>((bf16_sign << 15) | (bf16_exp << 7) |
366+
bf16_mant);
389367
}
390368
}
391369

@@ -401,7 +379,7 @@ class ConvertToBfloat16 {
401379
size_t msb_pos = get_msb_pos(u);
402380
// return half representation for 1
403381
if (msb_pos == 0)
404-
return bitsToBfloat16(0x3F80);
382+
return bit_cast<bfloat16, uint16_t>(0x3F80);
405383

406384
T mant = u & ((static_cast<T>(1) << msb_pos) - 1);
407385
// Unsigned integral value can be represented by 1.mant * (2^msb_pos),
@@ -442,7 +420,7 @@ class ConvertToBfloat16 {
442420
}
443421

444422
b_exp += 127;
445-
return bitsToBfloat16((b_exp << 7) | b_mant);
423+
return bit_cast<bfloat16, uint16_t>((b_exp << 7) | b_mant);
446424
}
447425

448426
// Helper function to get BF16 from signed integral data types.
@@ -459,7 +437,7 @@ class ConvertToBfloat16 {
459437
UTy ui = (i > 0) ? static_cast<UTy>(i) : static_cast<UTy>(-i);
460438
size_t msb_pos = get_msb_pos<UTy>(ui);
461439
if (msb_pos == 0)
462-
return bitsToBfloat16(b_sign ? 0xBF80 : 0x3F80);
440+
return bit_cast<bfloat16, uint16_t>(b_sign ? 0xBF80 : 0x3F80);
463441
UTy mant = ui & ((static_cast<UTy>(1) << msb_pos) - 1);
464442

465443
uint16_t b_exp = msb_pos;
@@ -495,7 +473,7 @@ class ConvertToBfloat16 {
495473
b_mant = 0;
496474
}
497475
b_exp += 127;
498-
return bitsToBfloat16(b_sign | (b_exp << 7) | b_mant);
476+
return bit_cast<bfloat16, uint16_t>(b_sign | (b_exp << 7) | b_mant);
499477
}
500478

501479
// Helper function to get BF16 from double with RTE rounding modes.

sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp

+45-57
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
#pragma once
1010

11-
#include <sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
12-
#include <sycl/detail/memcpy.hpp> // sycl::detail::memcpy
11+
#include <sycl/bit_cast.hpp> // for sycl::bit_cast
12+
#include <sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
13+
#include <sycl/detail/memcpy.hpp> // sycl::detail::memcpy
1314
#include <sycl/detail/vector_convert.hpp>
14-
#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16, bfloat16ToBits
15+
#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
1516
#include <sycl/marray.hpp> // for marray
1617

1718
#include <cstring> // for size_t
@@ -46,7 +47,7 @@ constexpr int num_elements_v = sycl::detail::num_elements<T>::value;
4647
// significand has non-zero bits.
4748
template <typename T>
4849
std::enable_if_t<std::is_same_v<T, bfloat16>, bool> isnan(T x) {
49-
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
50+
uint16_t XBits = bit_cast<uint16_t>(x);
5051
return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false;
5152
}
5253

@@ -90,15 +91,15 @@ template <typename T>
9091
std::enable_if_t<std::is_same_v<T, bfloat16>, T> fabs(T x) {
9192
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
9293
(__SYCL_CUDA_ARCH__ >= 800)
93-
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
94-
return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
94+
uint16_t XBits = bit_cast<uint16_t>(x);
95+
return bit_cast<bfloat16>(__clc_fabs(XBits));
9596
#else
9697
if (!isnan(x)) {
97-
const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000;
98-
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
98+
constexpr uint16_t SignMask = 0x8000;
99+
uint16_t XBits = bit_cast<uint16_t>(x);
99100
x = ((XBits & SignMask) == SignMask)
100-
? oneapi::detail::bitsToBfloat16(XBits & ~SignMask)
101-
: x;
101+
? bit_cast<bfloat16, uint16_t>(XBits & ~SignMask)
102+
: bit_cast<bfloat16>(x);
102103
}
103104
return x;
104105
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -116,9 +117,8 @@ sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
116117
}
117118

118119
if (N % 2) {
119-
oneapi::detail::Bfloat16StorageT XBits =
120-
oneapi::detail::bfloat16ToBits(x[N - 1]);
121-
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
120+
uint16_t XBits = bit_cast<uint16_t>(x[N - 1]);
121+
res[N - 1] = bit_cast<bfloat16>(__clc_fabs(XBits));
122122
}
123123
#else
124124
for (size_t i = 0; i < N; i++) {
@@ -154,25 +154,22 @@ template <typename T>
154154
std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmin(T x, T y) {
155155
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
156156
(__SYCL_CUDA_ARCH__ >= 800)
157-
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
158-
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
159-
return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
157+
uint16_t XBits = bit_cast<uint16_t>(x);
158+
uint16_t YBits = bit_cast<uint16_t>(y);
159+
return bit_cast<bfloat16>(__clc_fmin(XBits, YBits));
160160
#else
161-
static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
161+
constexpr uint16_t CanonicalNan = 0x7FC0;
162162
if (isnan(x) && isnan(y))
163-
return oneapi::detail::bitsToBfloat16(CanonicalNan);
163+
return bit_cast<bfloat16>(CanonicalNan);
164164

165165
if (isnan(x))
166166
return y;
167167
if (isnan(y))
168168
return x;
169-
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
170-
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
171-
if (((XBits | YBits) ==
172-
static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
173-
!(XBits & YBits))
174-
return oneapi::detail::bitsToBfloat16(
175-
static_cast<oneapi::detail::Bfloat16StorageT>(0x8000));
169+
uint16_t XBits = bit_cast<uint16_t>(x);
170+
uint16_t YBits = bit_cast<uint16_t>(y);
171+
if (((XBits | YBits) == static_cast<uint16_t>(0x8000)) && !(XBits & YBits))
172+
return bit_cast<bfloat16>(static_cast<uint16_t>(0x8000));
176173

177174
return (x < y) ? x : y;
178175
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -192,11 +189,9 @@ sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
192189
}
193190

194191
if (N % 2) {
195-
oneapi::detail::Bfloat16StorageT XBits =
196-
oneapi::detail::bfloat16ToBits(x[N - 1]);
197-
oneapi::detail::Bfloat16StorageT YBits =
198-
oneapi::detail::bfloat16ToBits(y[N - 1]);
199-
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
192+
uint16_t XBits = bit_cast<uint16_t>(x[N - 1]);
193+
uint16_t YBits = bit_cast<uint16_t>(y[N - 1]);
194+
res[N - 1] = bit_cast<bfloat16>(__clc_fmin(XBits, YBits));
200195
}
201196
#else
202197
for (size_t i = 0; i < N; i++) {
@@ -237,24 +232,22 @@ template <typename T>
237232
std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmax(T x, T y) {
238233
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
239234
(__SYCL_CUDA_ARCH__ >= 800)
240-
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
241-
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
242-
return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
235+
uint16_t XBits = bit_cast<uint16_t>(x);
236+
uint16_t YBits = bit_cast<uint16_t>(y);
237+
return bit_cast<bfloat16>(__clc_fmax(XBits, YBits));
243238
#else
244-
static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
239+
constexpr uint16_t CanonicalNan = 0x7FC0;
245240
if (isnan(x) && isnan(y))
246-
return oneapi::detail::bitsToBfloat16(CanonicalNan);
241+
return bit_cast<bfloat16>(CanonicalNan);
247242

248243
if (isnan(x))
249244
return y;
250245
if (isnan(y))
251246
return x;
252-
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
253-
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
254-
if (((XBits | YBits) ==
255-
static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
256-
!(XBits & YBits))
257-
return oneapi::detail::bitsToBfloat16(0);
247+
uint16_t XBits = bit_cast<uint16_t>(x);
248+
uint16_t YBits = bit_cast<uint16_t>(y);
249+
if (((XBits | YBits) == static_cast<uint16_t>(0x8000)) && !(XBits & YBits))
250+
return bit_cast<bfloat16, uint16_t>(0);
258251

259252
return (x > y) ? x : y;
260253
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -274,11 +267,9 @@ sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
274267
}
275268

276269
if (N % 2) {
277-
oneapi::detail::Bfloat16StorageT XBits =
278-
oneapi::detail::bfloat16ToBits(x[N - 1]);
279-
oneapi::detail::Bfloat16StorageT YBits =
280-
oneapi::detail::bfloat16ToBits(y[N - 1]);
281-
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
270+
uint16_t XBits = bit_cast<uint16_t>(x[N - 1]);
271+
uint16_t YBits = bit_cast<uint16_t>(y[N - 1]);
272+
res[N - 1] = bit_cast<bfloat16>(__clc_fmax(XBits, YBits));
282273
}
283274
#else
284275
for (size_t i = 0; i < N; i++) {
@@ -319,10 +310,10 @@ template <typename T>
319310
std::enable_if_t<std::is_same_v<T, bfloat16>, T> fma(T x, T y, T z) {
320311
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
321312
(__SYCL_CUDA_ARCH__ >= 800)
322-
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
323-
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
324-
oneapi::detail::Bfloat16StorageT ZBits = oneapi::detail::bfloat16ToBits(z);
325-
return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
313+
uint16_t XBits = bit_cast<uint16_t>(x);
314+
uint16_t YBits = bit_cast<uint16_t>(y);
315+
uint16_t ZBits = bit_cast<uint16_t>(z);
316+
return bit_cast<bfloat16>(__clc_fma(XBits, YBits, ZBits));
326317
#else
327318
return sycl::ext::oneapi::bfloat16{sycl::fma(float{x}, float{y}, float{z})};
328319
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -344,13 +335,10 @@ sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
344335
}
345336

346337
if (N % 2) {
347-
oneapi::detail::Bfloat16StorageT XBits =
348-
oneapi::detail::bfloat16ToBits(x[N - 1]);
349-
oneapi::detail::Bfloat16StorageT YBits =
350-
oneapi::detail::bfloat16ToBits(y[N - 1]);
351-
oneapi::detail::Bfloat16StorageT ZBits =
352-
oneapi::detail::bfloat16ToBits(z[N - 1]);
353-
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
338+
uint16_t XBits = bit_cast<uint16_t>(x[N - 1]);
339+
uint16_t YBits = bit_cast<uint16_t>(y[N - 1]);
340+
uint16_t ZBits = bit_cast<uint16_t>(z[N - 1]);
341+
res[N - 1] = bit_cast<bfloat16>(__clc_fma(XBits, YBits, ZBits));
354342
}
355343
#else
356344
for (size_t i = 0; i < N; i++) {

0 commit comments

Comments
 (0)