从 《Beautiful Code》这本书里面看到的一章《The Quest for an Accelerated Population Count》by Henrry S.Warren, Jr.他也是《Hacker Delight》的作者,那本书里面也收集了各种计算技巧,有时间可以拿来翻翻。这篇文章讲的就是如何计算一个整数中bit=1的数量。
UPDATE: 文章最后增加了性能对比,包括了 `__builtin_popcount` 的性能。
最简单的写法是循环32次,稍微好点的做法是提前判断是否为0,但是不知道branch predication的副作用有多大。如果值范围是可以固定的话,那么最好还是使用固定循环次数的写法,这样会更加有时间保证。
uint32_t popcount11(uint32_t x) {
uint32_t ans = 0;
while (x) {
ans += x & 0x1;
x = x >> 1;
}
return ans;
}
UPDATE: 其实可以换成 `x=x&(x-1)` 这样会更快,另外一个方式是使用表查询,效果好像比这个要更好。TABLE大小是 32 * 4 = 128 字节,占用两个cache line(64字节), 在内存访问效率上应该是可以的。
/*
data = []
for i in range(0,256,8):
value = 0
for j in reversed(range(8)):
value = (value << 4) | popcount(i+j)
data.append(value)
*/
uint32_t TABLE[] = {841031952, 1127363105, 1127363105, 1413694258, 1127363105,
1413694258, 1413694258, 1700025411, 1127363105, 1413694258,
1413694258, 1700025411, 1413694258, 1700025411, 1700025411,
1986356564, 1127363105, 1413694258, 1413694258, 1700025411,
1413694258, 1700025411, 1700025411, 1986356564, 1413694258,
1700025411, 1700025411, 1986356564, 1700025411, 1986356564,
1986356564, 2272687717};
inline uint32_t GET8(unsigned char x) {
return (TABLE[x >> 3] >> ((x & 0x7) << 2)) & 0xf;
}
uint32_t popcount01(uint32_t x) {
return GET8(x & 0xff) + GET8((x >> 8) & 0xff) + GET8((x >> 16) & 0xff) +
GET8((x >> 24) & 0xff);
}
如果采用分治思想的话,那么可以写成下面这样的代码,好处是没有循环分支,并且指令数量也更少了。
uint32_t _popcount21(uint32_t x) {
x = (x & 0x55555555) + ((x & 0xaaaaaaaa) >> 1);
x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2);
x = (x & 0x0f0f0f0f) + ((x & 0xf0f0f0f0) >> 4);
x = (x & 0x00ff00ff) + ((x & 0xff00ff00) >> 8);
x = (x & 0x0000ffff) + ((x & 0xffff0000) >> 16);
return x;
}
上面那个版本,其实和下面这个版本是等价的,好处是涉及到的常量少了,可能指令会更加精简。
uint32_t popcount21(uint32_t x) {
x = (x & 0x55555555) + ((x >> 1) & 0x55555555);
x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
x = (x & 0x0f0f0f0f) + ((x >> 4) & 0x0f0f0f0f);
x = (x & 0x00ff00ff) + ((x >> 8) & 0x00ff00ff);
x = (x & 0x0000ffff) + ((x >> 16) & 0x0000ffff);
return x;
}
但是如果仔细观察的话,可以发现从 `x>>4` 这里开始,其实相加就已经不会出现溢出了。因为high bits最多有4个1, low bits最多有4个1, 相加起来最多8个1, 完全可以放在4个bits里面,只不过最后我们需要在取个低位。所以上面的代码可以简化为下面这样
uint32_t __popcount21(uint32_t x) {
// 这里可以假设分别是0,1的情况
// 如果是11的话,那么11-01 = 10 = 2
// 10 - 01 = 01 = 1
// 0x 这个就是 x
x = x - ((x >> 1) & 0x55555555);
x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2);
x = (x + (x >> 4)) & 0x0f0f0f0f;
x = x + (x >> 8);
x = x + (x >> 16);
// 最后一次 low bits 最多 16, 就是 10000
// high bits 最多 16,也是 10000
// 所以最多就是 100000
return x & 0x3f;
}
上面的思想可以扩展到两个数,以及4个数,只要在合适的机会下面将两个数直接相加就好。
uint32_t popcount22(uint32_t x, uint32_t y) {
x = (x & 0x55555555) + ((x & 0xaaaaaaaa) >> 1);
x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2);
y = (y & 0x55555555) + ((y & 0xaaaaaaaa) >> 1);
y = (y & 0x33333333) + ((y & 0xcccccccc) >> 2);
x += y;
x = (x & 0x0f0f0f0f) + ((x & 0xf0f0f0f0) >> 4);
x = (x & 0x00ff00ff) + ((x & 0xff00ff00) >> 8);
x = (x & 0x0000ffff) + ((x & 0xffff0000) >> 16);
return x;
}
uint32_t popcount24(uint32_t x, uint32_t y, uint32_t a, uint32_t b) {
x = (x & 0x55555555) + ((x & 0xaaaaaaaa) >> 1);
y = (y & 0x55555555) + ((y & 0xaaaaaaaa) >> 1);
a = (a & 0x55555555) + ((a & 0xaaaaaaaa) >> 1);
b = (b & 0x55555555) + ((b & 0xaaaaaaaa) >> 1);
x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2);
y = (y & 0x33333333) + ((y & 0xcccccccc) >> 2);
a = (a & 0x33333333) + ((a & 0xcccccccc) >> 2);
b = (b & 0x33333333) + ((b & 0xcccccccc) >> 2);
x += y;
a += b;
x = (x & 0x0f0f0f0f) + ((x & 0xf0f0f0f0) >> 4);
a = (a & 0x0f0f0f0f) + ((a & 0xf0f0f0f0) >> 4);
x += a;
x = (x & 0x00ff00ff) + ((x & 0xff00ff00) >> 8);
x = (x & 0x0000ffff) + ((x & 0xffff0000) >> 16);
return x;
}
有了两个数的popcount求和,可以在上面做出扩展,比如求解 `pop(x) - pop(y)`, 这个式子可以变为 `pop(x) - (32 - pop(~y)) => pop(x) + pop(~y) - 32`
// pop(x) - pop(y) = pop(x) - (32 - pop(~y)) = pop(x) + pop(y) - 32
int popDiff(uint32_t x, uint32_t y) {
x = x - ((x >> 1) & 0x55555555);
x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
y = ~y;
y = y - ((y >> 1) & 0x55555555);
y = (y & 0x33333333) + ((y >> 2) & 0x33333333);
x += y;
x = (x + (x >> 4)) & 0x0f0f0f0f;
x = (x + (x >> 8));
x = (x + (x >> 16));
return x & 0x0000007f - 32;
}
此外还有个高效实现来比较较两个数的popcount,首先使用bits进行抵消,然后不断地去clear lsb, 然后看谁先为0.
int popCompare(uint32_t xp, uint32_t yp) {
unsigned x, y;
x = xp & ~yp;
y = yp & ~xp;
while (1) {
// if y == 0 then 0
// else < 0
if (x == 0) return y | -y;
if (y == 0) return 1;
x = x & (x - 1); // clear lsb
y = y & (y - 1);
}
}
还有使用avx512 vpopcount dq指令的实现,因为我的CPU不支持,所以也没有运行,不知道实现是否正确以及效果如何。
// don't use it. I don't have any cpu support avx512 vpopcnt dq.
// https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html
// g++ mm.cpp -g -W -Wall -mavx512f -mavx512vpopcntdq
uint32_t avx512_vpopcnt(const uint32_t* data, size_t size) {
uint32_t ans = 0;
uint64_t start = (uint64_t)data;
if ((start % 64) != 0) {
size_t rem = (start % 64) / 4;
start = (start + 63) / 64 * 64;
size -= rem;
FORI(i, rem) ans += popcount21(data[i]);
}
const uint8_t* ptr = (uint8_t*)start;
const uint8_t* end = ptr + size;
const size_t chunks = size / 64;
// count using AVX512 registers
__m512i accumulator = _mm512_setzero_si512();
for (size_t i = 0; i < chunks; i++, ptr += 64) {
// Note: a short chain of dependencies, likely unrolling will be needed.
const __m512i v = _mm512_loadu_si512((const __m512i*)ptr);
const __m512i p = _mm512_popcnt_epi64(v);
accumulator = _mm512_add_epi64(accumulator, p);
}
// horizontal sum of a register
uint64_t tmp[8] __attribute__((aligned(64)));
_mm512_store_si512((__m512i*)tmp, accumulator);
for (size_t i = 0; i < 8; i++) {
ans += (uint32_t)tmp[i];
}
// popcount the tail
while (ptr + 4 < end) {
ans += popcount21(*(uint32_t*)(ptr));
ptr += 4;
}
return ans;
}
下面是性能数据,代码可以看这里 这里
- level-2: `__builtin_popcount` 实现
- level-1: 打表实现
- level0: 循环移位实现
- level1,2,4: 分治算法实现
可以看到分治实现比内置实现效率还要高点
[level-2] N = 1000, took: 82ms, avg 82ns/N, ans = 443894796 [level-1] N = 1000, took: 106ms, avg 106ns/N, ans = 443894796 [level0] N = 1000, took: 337ms, avg 337ns/N, ans = 443894796 [level1] N = 1000, took: 55ms, avg 55ns/N, ans = 443894796 [level2] N = 1000, took: 37ms, avg 37ns/N, ans = 443894796 [level4] N = 1000, took: 32ms, avg 32ns/N, ans = 443894796