From ce11e66662e409259a477ae156b1e575301f9e6e Mon Sep 17 00:00:00 2001 From: Kim Walisch Date: Thu, 27 Jun 2024 17:39:27 +0200 Subject: [PATCH] Use SIMD for trailing bytes --- libpopcnt.h | 114 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 68 insertions(+), 46 deletions(-) diff --git a/libpopcnt.h b/libpopcnt.h index 438ae29..422566c 100644 --- a/libpopcnt.h +++ b/libpopcnt.h @@ -90,7 +90,7 @@ /* GCC compiler */ #if defined(LIBPOPCNT_X86_OR_X64) && \ - LIBPOPCNT_GNUC_PREREQ(4, 9) + LIBPOPCNT_GNUC_PREREQ(5, 0) #define LIBPOPCNT_HAVE_AVX2 #endif @@ -100,6 +100,20 @@ #define LIBPOPCNT_HAVE_AVX512 #endif +/* Clang (Unix-like OSes) */ +#if defined(LIBPOPCNT_X86_OR_X64) && !defined(_MSC_VER) + #if LIBPOPCNT_CLANG_PREREQ(3, 8) && \ + __has_attribute(target) && \ + (!defined(__apple_build_version__) || __apple_build_version__ >= 8000000) + #define LIBPOPCNT_HAVE_AVX2 + #endif + #if LIBPOPCNT_CLANG_PREREQ(9, 0) && \ + __has_attribute(target) && \ + (!defined(__apple_build_version__) || __apple_build_version__ >= 8000000) + #define LIBPOPCNT_HAVE_AVX512 + #endif +#endif + /* MSVC compatible compilers (Windows) */ #if defined(LIBPOPCNT_X86_OR_X64) && \ defined(_MSC_VER) @@ -127,20 +141,6 @@ #endif #endif -/* Clang (Unix-like OSes) */ -#if defined(LIBPOPCNT_X86_OR_X64) && !defined(_MSC_VER) - #if LIBPOPCNT_CLANG_PREREQ(3, 8) && \ - __has_attribute(target) && \ - (!defined(__apple_build_version__) || __apple_build_version__ >= 8000000) - #define LIBPOPCNT_HAVE_AVX2 - #endif - #if LIBPOPCNT_CLANG_PREREQ(9, 0) && \ - __has_attribute(target) && \ - (!defined(__apple_build_version__) || __apple_build_version__ >= 8000000) - #define LIBPOPCNT_HAVE_AVX512 - #endif -#endif - /* * Only enable CPUID runtime checks if this is really * needed. E.g. 
do not enable if user has compiled @@ -151,7 +151,11 @@ defined(_MSC_VER) || \ (LIBPOPCNT_GNUC_PREREQ(4, 2) || \ __has_builtin(__sync_val_compare_and_swap))) && \ - ((defined(LIBPOPCNT_HAVE_AVX512) && !(defined(__AVX512__) || (defined(__AVX512F__) && defined(__AVX512VPOPCNTDQ__)))) || \ + ((defined(LIBPOPCNT_HAVE_AVX512) && !(defined(__AVX512__) || \ + (defined(__AVX512F__) && \ + defined(__AVX512BW__) && \ + defined(__AVX512VPOPCNTDQ__) && \ + defined(__AVX512BITALG__)))) || \ (defined(LIBPOPCNT_HAVE_AVX2) && !defined(__AVX2__)) || \ (defined(LIBPOPCNT_HAVE_POPCNT) && !defined(__POPCNT__))) #define LIBPOPCNT_HAVE_CPUID @@ -256,12 +260,14 @@ static inline uint64_t popcnt64(uint64_t x) /* https://en.wikipedia.org/wiki/CPUID */ /* %ebx bit flags */ -#define LIBPOPCNT_BIT_AVX2 (1 << 5) -#define LIBPOPCNT_BIT_AVX512F (1 << 16) +#define LIBPOPCNT_BIT_AVX2 (1 << 5) +#define LIBPOPCNT_BIT_AVX512F (1 << 16) +#define LIBPOPCNT_BIT_AVX512BW (1 << 30) /* %ecx bit flags */ -#define LIBPOPCNT_BIT_POPCNT (1 << 23) +#define LIBPOPCNT_BIT_AVX512_BITALG (1 << 12) #define LIBPOPCNT_BIT_AVX512_VPOPCNTDQ (1 << 14) +#define LIBPOPCNT_BIT_POPCNT (1 << 23) /* xgetbv bit flags */ #define LIBPOPCNT_XSTATE_SSE (1 << 1) @@ -351,8 +357,12 @@ static inline int get_cpuid(void) if ((xcr0 & zmm_mask) == zmm_mask) { + /* If all AVX512 features required by our popcnt_avx512() are supported */ + /* then we add LIBPOPCNT_BIT_AVX512_VPOPCNTDQ to our CPUID flags. 
*/ if ((abcd[1] & LIBPOPCNT_BIT_AVX512F) == LIBPOPCNT_BIT_AVX512F && - (abcd[2] & LIBPOPCNT_BIT_AVX512_VPOPCNTDQ) == LIBPOPCNT_BIT_AVX512_VPOPCNTDQ) + (abcd[1] & LIBPOPCNT_BIT_AVX512BW) == LIBPOPCNT_BIT_AVX512BW && + (abcd[2] & LIBPOPCNT_BIT_AVX512_VPOPCNTDQ) == LIBPOPCNT_BIT_AVX512_VPOPCNTDQ && + (abcd[2] & LIBPOPCNT_BIT_AVX512_BITALG) == LIBPOPCNT_BIT_AVX512_BITALG) flags |= LIBPOPCNT_BIT_AVX512_VPOPCNTDQ; } } @@ -477,19 +487,21 @@ static inline uint64_t popcnt_avx2(const __m256i* ptr, uint64_t size) #include <immintrin.h> #if __has_attribute(target) - __attribute__ ((target ("avx512f,avx512vpopcntdq"))) + __attribute__ ((target ("avx512f,avx512bw,avx512vpopcntdq,avx512bitalg"))) #endif -static inline uint64_t popcnt_avx512(const uint64_t* ptr, const uint64_t size) +static inline uint64_t popcnt_avx512(const uint8_t* ptr8, uint64_t size) { __m512i cnt = _mm512_setzero_si512(); + const uint64_t* ptr64 = (const uint64_t*) ptr8; + uint64_t size64 = size / sizeof(uint64_t); uint64_t i = 0; - for (; i + 32 <= size; i += 32) + for (; i + 32 <= size64; i += 32) { - __m512i vec0 = _mm512_loadu_epi64(&ptr[i + 0]); - __m512i vec1 = _mm512_loadu_epi64(&ptr[i + 8]); - __m512i vec2 = _mm512_loadu_epi64(&ptr[i + 16]); - __m512i vec3 = _mm512_loadu_epi64(&ptr[i + 24]); + __m512i vec0 = _mm512_loadu_epi64(&ptr64[i + 0]); + __m512i vec1 = _mm512_loadu_epi64(&ptr64[i + 8]); + __m512i vec2 = _mm512_loadu_epi64(&ptr64[i + 16]); + __m512i vec3 = _mm512_loadu_epi64(&ptr64[i + 24]); vec0 = _mm512_popcnt_epi64(vec0); vec1 = _mm512_popcnt_epi64(vec1); @@ -502,21 +514,35 @@ static inline uint64_t popcnt_avx512(const uint64_t* ptr, const uint64_t size) cnt = _mm512_add_epi64(cnt, vec3); } - for (; i + 8 <= size; i += 8) + for (; i + 8 <= size64; i += 8) { - __m512i vec = _mm512_loadu_epi64(&ptr[i]); + __m512i vec = _mm512_loadu_epi64(&ptr64[i]); vec = _mm512_popcnt_epi64(vec); cnt = _mm512_add_epi64(cnt, vec); } - if (i < size) + /* Process the last partial 64-byte block (1-7 remaining 64-bit words) */ + if (i < size64) { - __mmask8 mask = 
(__mmask8) (0xff >> (i + 8 - size)); - __m512i vec = _mm512_maskz_loadu_epi64(mask , &ptr[i]); + __mmask8 mask = (__mmask8) (0xff >> (i + 8 - size64)); + __m512i vec = _mm512_maskz_loadu_epi64(mask , &ptr64[i]); vec = _mm512_popcnt_epi64(vec); cnt = _mm512_add_epi64(cnt, vec); } + uint64_t bytes = size % sizeof(uint64_t); + + /* Process the trailing 1-7 bytes that do not fill a 64-bit word */ + if (bytes != 0) + { + i = size - bytes; + __mmask64 mask = (__mmask64) (0xff >> (i + 8 - size)); + __m512i vec = _mm512_maskz_loadu_epi8(mask, &ptr8[i]); + __m512i cnt8 = _mm512_popcnt_epi8(vec); + cnt8 = _mm512_sad_epu8(cnt8, _mm512_setzero_si512()); + cnt = _mm512_add_epi64(cnt, cnt8); + } + return _mm512_reduce_add_epi64(cnt); } @@ -530,7 +556,7 @@ static inline uint64_t popcnt_avx512(const uint64_t* ptr, const uint64_t size) * @data: An array * @size: Size of data in bytes */ -static inline uint64_t popcnt(const void* data, uint64_t size) +static uint64_t popcnt(const void* data, uint64_t size) { /* * CPUID runtime checks are only enabled if this is needed. @@ -563,20 +589,17 @@ static inline uint64_t popcnt(const void* data, uint64_t size) #if defined(LIBPOPCNT_HAVE_AVX512) #if defined(__AVX512__) || \ - (defined(__AVX512F__) && defined(__AVX512VPOPCNTDQ__)) + (defined(__AVX512F__) && \ + defined(__AVX512BW__) && \ + defined(__AVX512VPOPCNTDQ__) && \ + defined(__AVX512BITALG__)) /* For tiny arrays AVX512 is not worth it */ if (i + 48 <= size) #else if ((cpuid & LIBPOPCNT_BIT_AVX512_VPOPCNTDQ) && i + 48 <= size) #endif - { - const uint64_t* ptr64 = (const uint64_t*)(ptr + i); - cnt += popcnt_avx512(ptr64, (size - i) / 8); - i = size - size % 8; - if (i == size) - return cnt; - } + return popcnt_avx512(ptr, size); #endif #if defined(LIBPOPCNT_HAVE_AVX2) @@ -735,11 +758,10 @@ static inline uint64_t popcnt(const void* data, uint64_t size) { i = size - bytes; const uint8_t* ptr8 = (const uint8_t*) data; - uint64_t val = 0; - bytes = (bytes <= 7) ? 
bytes : 7; - for (uint64_t j = 0; j < bytes; j++) - val |= ((uint64_t) ptr8[i + j]) << (j * 8); - cnt += popcnt64(val); + svbool_t pg8 = svwhilelt_b8(i, size); + svuint8_t vec = svld1_u8(pg8, &ptr8[i]); + svuint8_t vcnt8 = svcnt_u8_z(pg8, vec); + cnt += svaddv_u8(pg8, vcnt8); } return cnt;