Skip to content

Commit

Permalink
Use SIMD for trailing bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
kimwalisch committed Jun 27, 2024
1 parent 7d06f46 commit ce11e66
Showing 1 changed file with 68 additions and 46 deletions.
114 changes: 68 additions & 46 deletions libpopcnt.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@

/* GCC compiler */
#if defined(LIBPOPCNT_X86_OR_X64) && \
LIBPOPCNT_GNUC_PREREQ(4, 9)
LIBPOPCNT_GNUC_PREREQ(5, 0)
#define LIBPOPCNT_HAVE_AVX2
#endif

Expand All @@ -100,6 +100,20 @@
#define LIBPOPCNT_HAVE_AVX512
#endif

/* Clang (Unix-like OSes) */
#if defined(LIBPOPCNT_X86_OR_X64) && !defined(_MSC_VER)
#if LIBPOPCNT_CLANG_PREREQ(3, 8) && \
__has_attribute(target) && \
(!defined(__apple_build_version__) || __apple_build_version__ >= 8000000)
#define LIBPOPCNT_HAVE_AVX2
#endif
#if LIBPOPCNT_CLANG_PREREQ(9, 0) && \
__has_attribute(target) && \
(!defined(__apple_build_version__) || __apple_build_version__ >= 8000000)
#define LIBPOPCNT_HAVE_AVX512
#endif
#endif

/* MSVC compatible compilers (Windows) */
#if defined(LIBPOPCNT_X86_OR_X64) && \
defined(_MSC_VER)
Expand Down Expand Up @@ -127,20 +141,6 @@
#endif
#endif

/* Clang (Unix-like OSes) */
#if defined(LIBPOPCNT_X86_OR_X64) && !defined(_MSC_VER)
#if LIBPOPCNT_CLANG_PREREQ(3, 8) && \
__has_attribute(target) && \
(!defined(__apple_build_version__) || __apple_build_version__ >= 8000000)
#define LIBPOPCNT_HAVE_AVX2
#endif
#if LIBPOPCNT_CLANG_PREREQ(9, 0) && \
__has_attribute(target) && \
(!defined(__apple_build_version__) || __apple_build_version__ >= 8000000)
#define LIBPOPCNT_HAVE_AVX512
#endif
#endif

/*
* Only enable CPUID runtime checks if this is really
* needed. E.g. do not enable if user has compiled
Expand All @@ -151,7 +151,11 @@
defined(_MSC_VER) || \
(LIBPOPCNT_GNUC_PREREQ(4, 2) || \
__has_builtin(__sync_val_compare_and_swap))) && \
((defined(LIBPOPCNT_HAVE_AVX512) && !(defined(__AVX512__) || (defined(__AVX512F__) && defined(__AVX512VPOPCNTDQ__)))) || \
((defined(LIBPOPCNT_HAVE_AVX512) && !(defined(__AVX512__) || \
(defined(__AVX512F__) && \
defined(__AVX512BW__) && \
defined(__AVX512VPOPCNTDQ__) && \
defined(__AVX512BITALG__)))) || \
(defined(LIBPOPCNT_HAVE_AVX2) && !defined(__AVX2__)) || \
(defined(LIBPOPCNT_HAVE_POPCNT) && !defined(__POPCNT__)))
#define LIBPOPCNT_HAVE_CPUID
Expand Down Expand Up @@ -256,12 +260,14 @@ static inline uint64_t popcnt64(uint64_t x)
/* https://en.wikipedia.org/wiki/CPUID */

/* %ebx bit flags */
#define LIBPOPCNT_BIT_AVX2 (1 << 5)
#define LIBPOPCNT_BIT_AVX512F (1 << 16)
#define LIBPOPCNT_BIT_AVX2 (1 << 5)
#define LIBPOPCNT_BIT_AVX512F (1 << 16)
#define LIBPOPCNT_BIT_AVX512BW (1 << 30)

/* %ecx bit flags */
#define LIBPOPCNT_BIT_POPCNT (1 << 23)
#define LIBPOPCNT_BIT_AVX512_BITALG (1 << 12)
#define LIBPOPCNT_BIT_AVX512_VPOPCNTDQ (1 << 14)
#define LIBPOPCNT_BIT_POPCNT (1 << 23)

/* xgetbv bit flags */
#define LIBPOPCNT_XSTATE_SSE (1 << 1)
Expand Down Expand Up @@ -351,8 +357,12 @@ static inline int get_cpuid(void)

if ((xcr0 & zmm_mask) == zmm_mask)
{
/* If all AVX512 features required by our popcnt_avx512() are supported */
/* then we add LIBPOPCNT_BIT_AVX512_VPOPCNTDQ to our CPUID flags. */
if ((abcd[1] & LIBPOPCNT_BIT_AVX512F) == LIBPOPCNT_BIT_AVX512F &&
(abcd[2] & LIBPOPCNT_BIT_AVX512_VPOPCNTDQ) == LIBPOPCNT_BIT_AVX512_VPOPCNTDQ)
(abcd[1] & LIBPOPCNT_BIT_AVX512BW) == LIBPOPCNT_BIT_AVX512BW &&
(abcd[2] & LIBPOPCNT_BIT_AVX512_VPOPCNTDQ) == LIBPOPCNT_BIT_AVX512_VPOPCNTDQ &&
(abcd[2] & LIBPOPCNT_BIT_AVX512_BITALG) == LIBPOPCNT_BIT_AVX512_BITALG)
flags |= LIBPOPCNT_BIT_AVX512_VPOPCNTDQ;
}
}
Expand Down Expand Up @@ -477,19 +487,21 @@ static inline uint64_t popcnt_avx2(const __m256i* ptr, uint64_t size)
#include <immintrin.h>

#if __has_attribute(target)
__attribute__ ((target ("avx512f,avx512vpopcntdq")))
__attribute__ ((target ("avx512f,avx512bw,avx512vpopcntdq,avx512bitalg")))
#endif
static inline uint64_t popcnt_avx512(const uint64_t* ptr, const uint64_t size)
static inline uint64_t popcnt_avx512(const uint8_t* ptr8, uint64_t size)
{
__m512i cnt = _mm512_setzero_si512();
const uint64_t* ptr64 = (const uint64_t*) ptr8;
uint64_t size64 = size / sizeof(uint64_t);
uint64_t i = 0;

for (; i + 32 <= size; i += 32)
for (; i + 32 <= size64; i += 32)
{
__m512i vec0 = _mm512_loadu_epi64(&ptr[i + 0]);
__m512i vec1 = _mm512_loadu_epi64(&ptr[i + 8]);
__m512i vec2 = _mm512_loadu_epi64(&ptr[i + 16]);
__m512i vec3 = _mm512_loadu_epi64(&ptr[i + 24]);
__m512i vec0 = _mm512_loadu_epi64(&ptr64[i + 0]);
__m512i vec1 = _mm512_loadu_epi64(&ptr64[i + 8]);
__m512i vec2 = _mm512_loadu_epi64(&ptr64[i + 16]);
__m512i vec3 = _mm512_loadu_epi64(&ptr64[i + 24]);

vec0 = _mm512_popcnt_epi64(vec0);
vec1 = _mm512_popcnt_epi64(vec1);
Expand All @@ -502,21 +514,35 @@ static inline uint64_t popcnt_avx512(const uint64_t* ptr, const uint64_t size)
cnt = _mm512_add_epi64(cnt, vec3);
}

for (; i + 8 <= size; i += 8)
for (; i + 8 <= size64; i += 8)
{
__m512i vec = _mm512_loadu_epi64(&ptr[i]);
__m512i vec = _mm512_loadu_epi64(&ptr64[i]);
vec = _mm512_popcnt_epi64(vec);
cnt = _mm512_add_epi64(cnt, vec);
}

if (i < size)
/* Process last 64 bytes */
if (i < size64)
{
__mmask8 mask = (__mmask8) (0xff >> (i + 8 - size));
__m512i vec = _mm512_maskz_loadu_epi64(mask , &ptr[i]);
__mmask8 mask = (__mmask8) (0xff >> (i + 8 - size64));
__m512i vec = _mm512_maskz_loadu_epi64(mask , &ptr64[i]);
vec = _mm512_popcnt_epi64(vec);
cnt = _mm512_add_epi64(cnt, vec);
}

uint64_t bytes = size % sizeof(uint64_t);

/* Process last 8 bytes */
if (bytes != 0)
{
i = size - bytes;
__mmask64 mask = (__mmask64) (0xff >> (i + 8 - size));
__m512i vec = _mm512_maskz_loadu_epi8(mask, &ptr8[i]);
__m512i cnt8 = _mm512_popcnt_epi8(vec);
cnt8 = _mm512_sad_epu8(cnt8, _mm512_setzero_si512());
cnt = _mm512_add_epi64(cnt, cnt8);
}

return _mm512_reduce_add_epi64(cnt);
}

Expand All @@ -530,7 +556,7 @@ static inline uint64_t popcnt_avx512(const uint64_t* ptr, const uint64_t size)
* @data: An array
* @size: Size of data in bytes
*/
static inline uint64_t popcnt(const void* data, uint64_t size)
static uint64_t popcnt(const void* data, uint64_t size)
{
/*
* CPUID runtime checks are only enabled if this is needed.
Expand Down Expand Up @@ -563,20 +589,17 @@ static inline uint64_t popcnt(const void* data, uint64_t size)

#if defined(LIBPOPCNT_HAVE_AVX512)
#if defined(__AVX512__) || \
(defined(__AVX512F__) && defined(__AVX512VPOPCNTDQ__))
(defined(__AVX512F__) && \
defined(__AVX512BW__) && \
defined(__AVX512VPOPCNTDQ__) && \
defined(__AVX512BITALG__))
/* For tiny arrays AVX512 is not worth it */
if (i + 48 <= size)
#else
if ((cpuid & LIBPOPCNT_BIT_AVX512_VPOPCNTDQ) &&
i + 48 <= size)
#endif
{
const uint64_t* ptr64 = (const uint64_t*)(ptr + i);
cnt += popcnt_avx512(ptr64, (size - i) / 8);
i = size - size % 8;
if (i == size)
return cnt;
}
return popcnt_avx512(ptr, size);
#endif

#if defined(LIBPOPCNT_HAVE_AVX2)
Expand Down Expand Up @@ -735,11 +758,10 @@ static inline uint64_t popcnt(const void* data, uint64_t size)
{
i = size - bytes;
const uint8_t* ptr8 = (const uint8_t*) data;
uint64_t val = 0;
bytes = (bytes <= 7) ? bytes : 7;
for (uint64_t j = 0; j < bytes; j++)
val |= ((uint64_t) ptr8[i + j]) << (j * 8);
cnt += popcnt64(val);
svbool_t pg8 = svwhilelt_b8(i, size);
svuint8_t vec = svld1_u8(pg8, &ptr8[i]);
svuint8_t vcnt8 = svcnt_u8_z(pg8, vec);
cnt += svaddv_u8(pg8, vcnt8);
}

return cnt;
Expand Down

0 comments on commit ce11e66

Please sign in to comment.