diff --git a/include/hip/devicelib/half/half2_math.hh b/include/hip/devicelib/half/half2_math.hh index f1da2574f..52c051b02 100644 --- a/include/hip/devicelib/half/half2_math.hh +++ b/include/hip/devicelib/half/half2_math.hh @@ -26,6 +26,15 @@ #include +extern "C++" { + +extern __device__ api_half2 rint(api_half2 x); + +} + +static inline __device__ api_half2 rint_2h(api_half2 x) { return rint(x); } + + //__device__ __half2 h2ceil ( const __half2 h ) //__device__ __half2 h2cos ( const __half2 a ) //__device__ __half2 h2exp ( const __half2 a ) diff --git a/include/hip/devicelib/half/half_math.hh b/include/hip/devicelib/half/half_math.hh index e3da2ca38..08131dfea 100644 --- a/include/hip/devicelib/half/half_math.hh +++ b/include/hip/devicelib/half/half_math.hh @@ -26,6 +26,16 @@ #include + +extern "C++" { + +extern __device__ api_half rint(api_half x); + +} + +static inline __device__ api_half rint_h(api_half x) { return rint(x); } + + //__device__ __half hceil ( const __half h ) //__device__ __half hcos ( const __half a ) //__device__ __half hexp ( const __half a ) diff --git a/include/hip/spirv_hip_fp16.h b/include/hip/spirv_hip_fp16.h index d364f1711..49ed30015 100644 --- a/include/hip/spirv_hip_fp16.h +++ b/include/hip/spirv_hip_fp16.h @@ -377,11 +377,15 @@ inline __host__ __device__ __half2 make_half2(__half x, __half y) { return __half2{x, y}; } -inline __device__ __half __low2half(__half2 x) { +inline __host__ __device__ __half +__low2half (__half2 x) +{ return __half{__half_raw{static_cast<__half2_raw>(x).data.x}}; } -inline __device__ __half __high2half(__half2 x) { +inline __host__ __device__ __half +__high2half (__half2 x) +{ return __half{__half_raw{static_cast<__half2_raw>(x).data.y}}; } @@ -660,6 +664,49 @@ inline __device__ __half2 __ldcg(const __half2* ptr) { return *ptr; } inline __device__ __half2 __ldca(const __half2* ptr) { return *ptr; } inline __device__ __half2 __ldcs(const __half2* ptr) { return *ptr; } +// Store primitives +inline __device__ void +__stg (__half *ptr, const __half value) +{ + *ptr = value; +} +inline __device__ void +__stcg (__half *ptr, const __half value) +{ + *ptr = value; +} +inline __device__ void +__stca (__half *ptr, const __half value) +{ + *ptr = value; +} +inline __device__ void +__stcs (__half *ptr, const __half value) +{ + *ptr = value; +} + +inline __device__ void +__stg (__half2 *ptr, const __half2 value) +{ + *ptr = value; +} +inline __device__ void +__stcg (__half2 *ptr, const __half2 value) +{ + *ptr = value; +} +inline __device__ void +__stca (__half2 *ptr, const __half2 value) +{ + *ptr = value; +} +inline __device__ void +__stcs (__half2 *ptr, const __half2 value) +{ + *ptr = value; +} + // Relations inline __device__ bool __heq(__half x, __half y) { return static_cast<__half_raw>(x).data == static_cast<__half_raw>(y).data; @@ -763,6 +810,11 @@ inline __device__ __half __clamp_01(__half x) { return r; } +inline __device__ __half +__habs (__half x) +{ + return __half_raw{ fabs_h (static_cast<__half_raw> (x).data) }; +} inline __device__ __half __hadd(__half x, __half y) { return __half_raw{static_cast<_hip_f16>(static_cast<__half_raw>(x).data + static_cast<__half_raw>(y).data)}; @@ -797,6 +849,11 @@ inline __device__ __half __hdiv(__half x, __half y) { static_cast<__half_raw>(y).data)}; } +inline __device__ __half2 +__habs2 (__half2 x) +{ + return __half2_raw{ fabs_2h (static_cast<__half2_raw> (x).data) }; +} inline __device__ __half2 __hadd2(__half2 x, __half2 y) { return __half2_raw{static_cast<__half2_raw>(x).data + static_cast<__half2_raw>(y).data}; @@ -846,9 +903,11 @@ inline __device__ __half hceil(__half x) { inline __device__ __half hfloor(__half x) { return __half_raw{floor_h(static_cast<__half_raw>(x).data)}; } -// inline __device__ __half hrint(__half x) { -// return __half_raw{rint_h(static_cast<__half_raw>(x).data)}; -// } +inline __device__ __half +hrint (__half x) +{ + return __half_raw{ rint_h (static_cast<__half_raw> (x).data) }; +} inline __device__ __half hsin(__half x) { return __half_raw{sin_h(static_cast<__half_raw>(x).data)}; } @@ -873,11 +932,23 @@ inline __device__ __half hlog(__half x) { inline __device__ __half hlog10(__half x) { return __half_raw{log10_h(static_cast<__half_raw>(x).data)}; } +inline __device__ __half +hmin (__half x, __half y) +{ + return __half_raw{ fmin_h (static_cast<__half_raw> (x).data, + static_cast<__half_raw> (y).data) }; +} +inline __device__ __half +hmax (__half x, __half y) +{ + return __half_raw{ fmax_h (static_cast<__half_raw> (x).data, + static_cast<__half_raw> (y).data) }; +} inline __device__ __half hrcp(__half x) { - return __half_raw{// TODO - // recip_h(static_cast<__half_raw>(x).data)}; - log10_h(static_cast<__half_raw>(x).data)}; + return __half_raw{ static_cast<_hip_f16> ( + __half_raw{ 1 }.data / static_cast<__half_raw> (x).data) }; } + inline __device__ __half hrsqrt(__half x) { return __half_raw{rsqrt_h(static_cast<__half_raw>(x).data)}; } @@ -901,8 +972,11 @@ inline __device__ __half2 h2ceil(__half2 x) { return __half2_raw{ceil_2h(x)}; } inline __device__ __half2 h2floor(__half2 x) { return __half2_raw{floor_2h(x)}; } -// inline __device__ __half2 h2rint(__half2 x) { return __half2_raw{rint_2h(x)}; -// } +inline __device__ __half2 +h2rint (__half2 x) +{ + return __half2_raw{ rint_2h (x) }; +} inline __device__ __half2 h2sin(__half2 x) { return __half2_raw{sin_2h(x)}; } inline __device__ __half2 h2cos(__half2 x) { return __half2_raw{cos_2h(x)}; } inline __device__ __half2 h2exp(__half2 x) { return __half2_raw{exp_2h(x)}; } @@ -913,9 +987,22 @@ inline __device__ __half2 h2exp10(__half2 x) { inline __device__ __half2 h2log2(__half2 x) { return __half2_raw{log2_2h(x)}; } inline __device__ __half2 h2log(__half2 x) { return log_2h(x); } inline __device__ __half2 h2log10(__half2 x) { return log10_2h(x); } -inline __device__ __half2 h2rcp(__half2 x) { return log10_2h(x); } -// TODO -//__half2 h2rcp(__half2 x) { return recip_2h(x); } +inline __device__ __half2 +hmin2 (__half2 x, __half2 y) +{ + return fmin_2h (x, y); +} +inline __device__ __half2 +hmax2 (__half2 x, __half2 y) +{ + return fmax_2h (x, y); +} +inline __device__ __half2 +h2rcp (__half2 x) +{ + return __half2_raw{ static_cast<_hip_f16_2> ( + __half2_raw{ 1 }.data / static_cast<__half2_raw> (x).data) }; +} inline __device__ __half2 h2rsqrt(__half2 x) { return rsqrt_2h(x); } inline __device__ __half2 h2sqrt(__half2 x) { return sqrt_2h(x); } inline __device__ __half2 __hisinf2(__half2 x) {