Skip to content

Commit

Permalink
Fix some half/half2 intrinsics
Browse files Browse the repository at this point in the history
  • Loading branch information
franz committed Oct 26, 2022
1 parent 1385070 commit 3b184c6
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 13 deletions.
9 changes: 9 additions & 0 deletions include/hip/devicelib/half/half2_math.hh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@

#include <hip/devicelib/macros.hh>

extern "C++" {

extern __device__ api_half2 rint(api_half2 x);

}

static inline __device__ api_half2 rint_2h(api_half2 x) { return rint(x); }


//__device__ __half2 h2ceil ( const __half2 h )
//__device__ __half2 h2cos ( const __half2 a )
//__device__ __half2 h2exp ( const __half2 a )
Expand Down
10 changes: 10 additions & 0 deletions include/hip/devicelib/half/half_math.hh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@

#include <hip/devicelib/macros.hh>


extern "C++" {

extern __device__ api_half rint(api_half x);

}

static inline __device__ api_half rint_h(api_half x) { return rint(x); }


//__device__ __half hceil ( const __half h )
//__device__ __half hcos ( const __half a )
//__device__ __half hexp ( const __half a )
Expand Down
113 changes: 100 additions & 13 deletions include/hip/spirv_hip_fp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,11 +377,15 @@ inline __host__ __device__ __half2 make_half2(__half x, __half y) {
return __half2{x, y};
}

inline __device__ __half __low2half(__half2 x) {
inline __host__ __device__ __half
__low2half (__half2 x)
{
return __half{__half_raw{static_cast<__half2_raw>(x).data.x}};
}

inline __device__ __half __high2half(__half2 x) {
inline __host__ __device__ __half
__high2half (__half2 x)
{
return __half{__half_raw{static_cast<__half2_raw>(x).data.y}};
}

Expand Down Expand Up @@ -660,6 +664,49 @@ inline __device__ __half2 __ldcg(const __half2* ptr) { return *ptr; }
inline __device__ __half2 __ldca(const __half2* ptr) { return *ptr; }
inline __device__ __half2 __ldcs(const __half2* ptr) { return *ptr; }

// Store primitives
inline __device__ void
__stg (__half *ptr, const __half value)
{
*ptr = value;
}
inline __device__ void
__stcg (__half *ptr, const __half value)
{
*ptr = value;
}
inline __device__ void
__stca (__half *ptr, const __half value)
{
*ptr = value;
}
inline __device__ void
__stcs (__half *ptr, const __half value)
{
*ptr = value;
}

inline __device__ void
__stg (__half2 *ptr, const __half2 value)
{
*ptr = value;
}
inline __device__ void
__stcg (__half2 *ptr, const __half2 value)
{
*ptr = value;
}
inline __device__ void
__stca (__half2 *ptr, const __half2 value)
{
*ptr = value;
}
inline __device__ void
__stcs (__half2 *ptr, const __half2 value)
{
*ptr = value;
}

// Relations
inline __device__ bool __heq(__half x, __half y) {
return static_cast<__half_raw>(x).data == static_cast<__half_raw>(y).data;
Expand Down Expand Up @@ -763,6 +810,11 @@ inline __device__ __half __clamp_01(__half x) {
return r;
}

inline __device__ __half
__habs (__half x)
{
return __half_raw{ fabs_h (static_cast<__half_raw> (x).data) };
}
inline __device__ __half __hadd(__half x, __half y) {
return __half_raw{static_cast<_hip_f16>(static_cast<__half_raw>(x).data +
static_cast<__half_raw>(y).data)};
Expand Down Expand Up @@ -797,6 +849,11 @@ inline __device__ __half __hdiv(__half x, __half y) {
static_cast<__half_raw>(y).data)};
}

inline __device__ __half2
__habs2 (__half2 x)
{
return __half2_raw{ fabs_2h (static_cast<__half2_raw> (x).data) };
}
inline __device__ __half2 __hadd2(__half2 x, __half2 y) {
return __half2_raw{static_cast<__half2_raw>(x).data +
static_cast<__half2_raw>(y).data};
Expand Down Expand Up @@ -846,9 +903,11 @@ inline __device__ __half hceil(__half x) {
inline __device__ __half hfloor(__half x) {
return __half_raw{floor_h(static_cast<__half_raw>(x).data)};
}
// inline __device__ __half hrint(__half x) {
// return __half_raw{rint_h(static_cast<__half_raw>(x).data)};
// }
inline __device__ __half
hrint (__half x)
{
return __half_raw{ rint_h (static_cast<__half_raw> (x).data) };
}
inline __device__ __half hsin(__half x) {
return __half_raw{sin_h(static_cast<__half_raw>(x).data)};
}
Expand All @@ -873,11 +932,23 @@ inline __device__ __half hlog(__half x) {
inline __device__ __half hlog10(__half x) {
return __half_raw{log10_h(static_cast<__half_raw>(x).data)};
}
inline __device__ __half
hmin (__half x, __half y)
{
return __half_raw{ fmin_h (static_cast<__half_raw> (x).data,
static_cast<__half_raw> (y).data) };
}
inline __device__ __half
hmax (__half x, __half y)
{
return __half_raw{ fmax_h (static_cast<__half_raw> (x).data,
static_cast<__half_raw> (y).data) };
}
inline __device__ __half hrcp(__half x) {
return __half_raw{// TODO
// recip_h(static_cast<__half_raw>(x).data)};
log10_h(static_cast<__half_raw>(x).data)};
return __half_raw{ static_cast<_hip_f16> (
__half_raw{ 1 }.data / static_cast<__half_raw> (x).data) };
}

inline __device__ __half hrsqrt(__half x) {
return __half_raw{rsqrt_h(static_cast<__half_raw>(x).data)};
}
Expand All @@ -901,8 +972,11 @@ inline __device__ __half2 h2ceil(__half2 x) { return __half2_raw{ceil_2h(x)}; }
inline __device__ __half2 h2floor(__half2 x) {
return __half2_raw{floor_2h(x)};
}
// inline __device__ __half2 h2rint(__half2 x) { return __half2_raw{rint_2h(x)};
// }
inline __device__ __half2
h2rint (__half2 x)
{
return __half2_raw{ rint_2h (x) };
}
inline __device__ __half2 h2sin(__half2 x) { return __half2_raw{sin_2h(x)}; }
inline __device__ __half2 h2cos(__half2 x) { return __half2_raw{cos_2h(x)}; }
inline __device__ __half2 h2exp(__half2 x) { return __half2_raw{exp_2h(x)}; }
Expand All @@ -913,9 +987,22 @@ inline __device__ __half2 h2exp10(__half2 x) {
inline __device__ __half2 h2log2(__half2 x) { return __half2_raw{log2_2h(x)}; }
inline __device__ __half2 h2log(__half2 x) { return log_2h(x); }
inline __device__ __half2 h2log10(__half2 x) { return log10_2h(x); }
inline __device__ __half2 h2rcp(__half2 x) { return log10_2h(x); }
// TODO
//__half2 h2rcp(__half2 x) { return recip_2h(x); }
inline __device__ __half2
hmin2 (__half2 x, __half2 y)
{
return fmin_2h (x, y);
}
inline __device__ __half2
hmax2 (__half2 x, __half2 y)
{
return fmax_2h (x, y);
}
inline __device__ __half2
h2rcp (__half2 x)
{
return __half2_raw{ static_cast<_hip_f16_2> (
__half2_raw{ 1 }.data / static_cast<__half2_raw> (x).data) };
}
inline __device__ __half2 h2rsqrt(__half2 x) { return rsqrt_2h(x); }
inline __device__ __half2 h2sqrt(__half2 x) { return sqrt_2h(x); }
inline __device__ __half2 __hisinf2(__half2 x) {
Expand Down

0 comments on commit 3b184c6

Please sign in to comment.