8
8
9
9
#pragma once
10
10
11
- #include < sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
12
- #include < sycl/detail/memcpy.hpp> // sycl::detail::memcpy
11
+ #include < sycl/bit_cast.hpp> // for sycl::bit_cast
12
+ #include < sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
13
+ #include < sycl/detail/memcpy.hpp> // sycl::detail::memcpy
13
14
#include < sycl/detail/vector_convert.hpp>
14
- #include < sycl/ext/oneapi/bfloat16.hpp> // for bfloat16, bfloat16ToBits
15
+ #include < sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
15
16
#include < sycl/marray.hpp> // for marray
16
17
17
18
#include < cstring> // for size_t
@@ -46,7 +47,7 @@ constexpr int num_elements_v = sycl::detail::num_elements<T>::value;
46
47
// significand has non-zero bits.
47
48
template <typename T>
48
49
std::enable_if_t <std::is_same_v<T, bfloat16>, bool > isnan (T x) {
49
- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
50
+ uint16_t XBits = bit_cast< uint16_t > (x);
50
51
return (((XBits & 0x7F80 ) == 0x7F80 ) && (XBits & 0x7F )) ? true : false ;
51
52
}
52
53
@@ -90,15 +91,15 @@ template <typename T>
90
91
std::enable_if_t <std::is_same_v<T, bfloat16>, T> fabs (T x) {
91
92
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
92
93
(__SYCL_CUDA_ARCH__ >= 800 )
93
- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
94
- return oneapi::detail::bitsToBfloat16 (__clc_fabs (XBits));
94
+ uint16_t XBits = bit_cast< uint16_t > (x);
95
+ return bit_cast<bfloat16> (__clc_fabs (XBits));
95
96
#else
96
97
if (!isnan (x)) {
97
- const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000 ;
98
- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
98
+ constexpr uint16_t SignMask = 0x8000 ;
99
+ uint16_t XBits = bit_cast< uint16_t > (x);
99
100
x = ((XBits & SignMask) == SignMask)
100
- ? oneapi::detail::bitsToBfloat16 (XBits & ~SignMask)
101
- : x ;
101
+ ? bit_cast<bfloat16, uint16_t > (XBits & ~SignMask)
102
+ : bit_cast<bfloat16>(x) ;
102
103
}
103
104
return x;
104
105
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -116,9 +117,8 @@ sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
116
117
}
117
118
118
119
if (N % 2 ) {
119
- oneapi::detail::Bfloat16StorageT XBits =
120
- oneapi::detail::bfloat16ToBits (x[N - 1 ]);
121
- res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fabs (XBits));
120
+ uint16_t XBits = bit_cast<uint16_t >(x[N - 1 ]);
121
+ res[N - 1 ] = bit_cast<bfloat16>(__clc_fabs (XBits));
122
122
}
123
123
#else
124
124
for (size_t i = 0 ; i < N; i++) {
@@ -154,25 +154,22 @@ template <typename T>
154
154
std::enable_if_t <std::is_same_v<T, bfloat16>, T> fmin (T x, T y) {
155
155
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
156
156
(__SYCL_CUDA_ARCH__ >= 800 )
157
- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
158
- oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
159
- return oneapi::detail::bitsToBfloat16 (__clc_fmin (XBits, YBits));
157
+ uint16_t XBits = bit_cast< uint16_t > (x);
158
+ uint16_t YBits = bit_cast< uint16_t > (y);
159
+ return bit_cast<bfloat16> (__clc_fmin (XBits, YBits));
160
160
#else
161
- static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0 ;
161
+ constexpr uint16_t CanonicalNan = 0x7FC0 ;
162
162
if (isnan (x) && isnan (y))
163
- return oneapi::detail::bitsToBfloat16 (CanonicalNan);
163
+ return bit_cast<bfloat16> (CanonicalNan);
164
164
165
165
if (isnan (x))
166
166
return y;
167
167
if (isnan (y))
168
168
return x;
169
- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
170
- oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
171
- if (((XBits | YBits) ==
172
- static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 )) &&
173
- !(XBits & YBits))
174
- return oneapi::detail::bitsToBfloat16 (
175
- static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 ));
169
+ uint16_t XBits = bit_cast<uint16_t >(x);
170
+ uint16_t YBits = bit_cast<uint16_t >(y);
171
+ if (((XBits | YBits) == static_cast <uint16_t >(0x8000 )) && !(XBits & YBits))
172
+ return bit_cast<bfloat16>(static_cast <uint16_t >(0x8000 ));
176
173
177
174
return (x < y) ? x : y;
178
175
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -192,11 +189,9 @@ sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
192
189
}
193
190
194
191
if (N % 2 ) {
195
- oneapi::detail::Bfloat16StorageT XBits =
196
- oneapi::detail::bfloat16ToBits (x[N - 1 ]);
197
- oneapi::detail::Bfloat16StorageT YBits =
198
- oneapi::detail::bfloat16ToBits (y[N - 1 ]);
199
- res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fmin (XBits, YBits));
192
+ uint16_t XBits = bit_cast<uint16_t >(x[N - 1 ]);
193
+ uint16_t YBits = bit_cast<uint16_t >(y[N - 1 ]);
194
+ res[N - 1 ] = bit_cast<bfloat16>(__clc_fmin (XBits, YBits));
200
195
}
201
196
#else
202
197
for (size_t i = 0 ; i < N; i++) {
@@ -237,24 +232,22 @@ template <typename T>
237
232
std::enable_if_t <std::is_same_v<T, bfloat16>, T> fmax (T x, T y) {
238
233
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
239
234
(__SYCL_CUDA_ARCH__ >= 800 )
240
- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
241
- oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
242
- return oneapi::detail::bitsToBfloat16 (__clc_fmax (XBits, YBits));
235
+ uint16_t XBits = bit_cast< uint16_t > (x);
236
+ uint16_t YBits = bit_cast< uint16_t > (y);
237
+ return bit_cast<bfloat16> (__clc_fmax (XBits, YBits));
243
238
#else
244
- static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0 ;
239
+ constexpr uint16_t CanonicalNan = 0x7FC0 ;
245
240
if (isnan (x) && isnan (y))
246
- return oneapi::detail::bitsToBfloat16 (CanonicalNan);
241
+ return bit_cast<bfloat16> (CanonicalNan);
247
242
248
243
if (isnan (x))
249
244
return y;
250
245
if (isnan (y))
251
246
return x;
252
- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
253
- oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
254
- if (((XBits | YBits) ==
255
- static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 )) &&
256
- !(XBits & YBits))
257
- return oneapi::detail::bitsToBfloat16 (0 );
247
+ uint16_t XBits = bit_cast<uint16_t >(x);
248
+ uint16_t YBits = bit_cast<uint16_t >(y);
249
+ if (((XBits | YBits) == static_cast <uint16_t >(0x8000 )) && !(XBits & YBits))
250
+ return bit_cast<bfloat16, uint16_t >(0 );
258
251
259
252
return (x > y) ? x : y;
260
253
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -274,11 +267,9 @@ sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
274
267
}
275
268
276
269
if (N % 2 ) {
277
- oneapi::detail::Bfloat16StorageT XBits =
278
- oneapi::detail::bfloat16ToBits (x[N - 1 ]);
279
- oneapi::detail::Bfloat16StorageT YBits =
280
- oneapi::detail::bfloat16ToBits (y[N - 1 ]);
281
- res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fmax (XBits, YBits));
270
+ uint16_t XBits = bit_cast<uint16_t >(x[N - 1 ]);
271
+ uint16_t YBits = bit_cast<uint16_t >(y[N - 1 ]);
272
+ res[N - 1 ] = bit_cast<bfloat16>(__clc_fmax (XBits, YBits));
282
273
}
283
274
#else
284
275
for (size_t i = 0 ; i < N; i++) {
@@ -319,10 +310,10 @@ template <typename T>
319
310
std::enable_if_t <std::is_same_v<T, bfloat16>, T> fma (T x, T y, T z) {
320
311
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
321
312
(__SYCL_CUDA_ARCH__ >= 800 )
322
- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
323
- oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
324
- oneapi::detail::Bfloat16StorageT ZBits = oneapi::detail::bfloat16ToBits (z);
325
- return oneapi::detail::bitsToBfloat16 (__clc_fma (XBits, YBits, ZBits));
313
+ uint16_t XBits = bit_cast< uint16_t > (x);
314
+ uint16_t YBits = bit_cast< uint16_t > (y);
315
+ uint16_t ZBits = bit_cast< uint16_t > (z);
316
+ return bit_cast<bfloat16> (__clc_fma (XBits, YBits, ZBits));
326
317
#else
327
318
return sycl::ext::oneapi::bfloat16{sycl::fma (float {x}, float {y}, float {z})};
328
319
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -344,13 +335,10 @@ sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
344
335
}
345
336
346
337
if (N % 2 ) {
347
- oneapi::detail::Bfloat16StorageT XBits =
348
- oneapi::detail::bfloat16ToBits (x[N - 1 ]);
349
- oneapi::detail::Bfloat16StorageT YBits =
350
- oneapi::detail::bfloat16ToBits (y[N - 1 ]);
351
- oneapi::detail::Bfloat16StorageT ZBits =
352
- oneapi::detail::bfloat16ToBits (z[N - 1 ]);
353
- res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fma (XBits, YBits, ZBits));
338
+ uint16_t XBits = bit_cast<uint16_t >(x[N - 1 ]);
339
+ uint16_t YBits = bit_cast<uint16_t >(y[N - 1 ]);
340
+ uint16_t ZBits = bit_cast<uint16_t >(z[N - 1 ]);
341
+ res[N - 1 ] = bit_cast<bfloat16>(__clc_fma (XBits, YBits, ZBits));
354
342
}
355
343
#else
356
344
for (size_t i = 0 ; i < N; i++) {
0 commit comments