diff --git a/Source/SimdImpl_AVX2FMA.cpp b/Source/SimdImpl_AVX2FMA.cpp index 405327f..7242778 100644 --- a/Source/SimdImpl_AVX2FMA.cpp +++ b/Source/SimdImpl_AVX2FMA.cpp @@ -57,6 +57,24 @@ inline __m256d pow2n(const __m256d n) { __m256d d = _mm256_castsi256_pd(c); // bit-cast back to double return d; } + +inline __m256d abs_pd_avx(__m256d xIn) +{ + unsigned int data[2] = {0xFFFFFFFFu, 0x7FFFFFFFu}; + __m256d mask = _mm256_broadcast_sd((double*)data); + return _mm256_and_pd(xIn, mask); +} + +inline __m256i finite_mask_avx(__m256d x) +{ + __m256i i = _mm256_castpd_si256(x); + __m256i iShift = _mm256_sll_epi64(i, _mm_cvtsi64_si128(1)); + __m256i exp_val = _mm256_set1_epi64x(0xFFE0000000000000); + __m256i result = ~_mm256_cmpeq_epi64(_mm256_and_si256(iShift, exp_val), exp_val); + return result; +} + + // NOTE(cmo): AVX impl of exp_pd, based on Agner Fog's vector class // https://github.com/vectorclass/version2/blob/master/vectormath_exp.h // The implementation here, based on a classic Taylor series, rather than a @@ -98,7 +116,24 @@ inline __m256d exp_pd_avx(__m256d xIn) __m256d n2 = pow2n(r); z = _mm256_mul_pd(_mm256_add_pd(z, _mm256_set1_pd(1.0)), n2); - // TODO(cmo): Probably should have some of the nan/inf error handling code. + // NOTE(cmo): Error/edge-case handling code. The previous warning was prophetic. + // abs(xIn) < xMax + __m256i mask1 = _mm256_castpd_si256(_mm256_cmp_pd(abs_pd_avx(xIn), _mm256_set1_pd(xMax), 1)); + __m256i mask2 = finite_mask_avx(xIn); + __m256i mask = mask1 & mask2; + // if all mask is set, then exit normally. + if (_mm256_testc_si256(mask, _mm256_set1_epi64x(-1)) != 0) + return z; + + __m256d maskd = _mm256_castsi256_pd(mask); + __m256d inputSign = _mm256_and_pd(xIn, _mm256_set1_pd(-0.0)); + __m256d inf = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7FF0000000000000)); + r = _mm256_blendv_pd(inf, _mm256_set1_pd(0.0), inputSign); // values for over/underflow/inf + z = _mm256_blendv_pd(r, z, maskd); // +/- underflow + + __m256d nan_mask = _mm256_cmp_pd(xIn, xIn, 3); // check for unordered comparison, i.e. a value is nan + z = _mm256_blendv_pd(z, xIn, nan_mask); // set output to nan if input is nan + return z; } diff --git a/Source/SimdImpl_AVX512.cpp b/Source/SimdImpl_AVX512.cpp index f43f67b..0474bea 100644 --- a/Source/SimdImpl_AVX512.cpp +++ b/Source/SimdImpl_AVX512.cpp @@ -57,6 +57,23 @@ inline __m512d pow2n(const __m512d n) { __m512d d = _mm512_castsi512_pd(c); // bit-cast back to double return d; } + +inline __m512d abs_pd_avx512(__m512d xIn) +{ + unsigned int data[2] = {0xFFFFFFFFu, 0x7FFFFFFFu}; + __m512d mask = _mm512_broadcast_f64x4(_mm256_broadcast_sd(((double*)data))); + return _mm512_and_pd(xIn, mask); +} + +inline __mmask8 finite_mask_avx512(__m512d x) +{ + __m512i i = _mm512_castpd_si512(x); + __m512i iShift = _mm512_sll_epi64(i, _mm_cvtsi64_si128(1)); + __m512i exp_val = _mm512_set1_epi64(0xFFE0000000000000); + __mmask8 result = ~_mm512_cmpeq_epi64_mask(_mm512_and_si512(iShift, exp_val), exp_val); + return result; +} + // NOTE(cmo): AVX impl of exp_pd, based on Agner Fog's vector class // https://github.com/vectorclass/version2/blob/master/vectormath_exp.h // The implementation here, based on a classic Taylor series, rather than a @@ -98,7 +115,23 @@ inline __m512d exp_pd_avx512(__m512d xIn) __m512d n2 = pow2n(r); z = _mm512_mul_pd(_mm512_add_pd(z, _mm512_set1_pd(1.0)), n2); - // TODO(cmo): Probably should have some of the nan/inf error handling code. + // NOTE(cmo): Error/edge-case handling code. The previous warning was prophetic. + // abs(xIn) < xMax + __mmask8 mask1 = _mm512_cmp_pd_mask(abs_pd_avx512(xIn), _mm512_set1_pd(xMax), 1); + __mmask8 mask2 = finite_mask_avx512(xIn); + __mmask8 mask = mask1 & mask2; + // if all mask is set, then exit normally. + if (mask == 255) + return z; + + // __m256d maskd = _mm256_castsi256_pd(mask); + __mmask8 inputNegative = _mm512_cmp_pd_mask(xIn, _mm512_set1_pd(0.0), 1); + __m512d inf = _mm512_castsi512_pd(_mm512_set1_epi64(0x7FF0000000000000)); + r = _mm512_mask_blend_pd(inputNegative, inf, _mm512_set1_pd(0.0)); // values for over/underflow/inf + z = _mm512_mask_blend_pd(mask, r, z); // +/- underflow + + __mmask8 nan_mask = _mm512_cmp_pd_mask(xIn, xIn, 3); // check for unordered comparison, i.e. a value is nan + z = _mm512_mask_blend_pd(nan_mask, z, xIn); // set output to nan if input is nan return z; }