Skip to content

Commit

Permalink
Speed up length on ARM
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodigrim authored and Lysxia committed Sep 14, 2024
1 parent b0b2b61 commit e4e110e
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 1 deletion.
109 changes: 109 additions & 0 deletions cbits/aarch64/measure_off.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) 2021 Andrew Lelechenko <[email protected]>
*/

#include <string.h>
#include <sys/types.h>
#include <arm_neon.h>

/*
measure_off_naive / measure_off_neon
take a UTF-8 sequence between src and srcend, and a number of characters cnt.
If the sequence is long enough to contain cnt characters, then return how many bytes
remained unconsumed. Otherwise, if the sequence is shorter, return
negated count of lacking characters. Cf. _hs_text_measure_off below.
*/

static inline const ssize_t measure_off_naive(const uint8_t *src, const uint8_t *srcend, size_t cnt)
{
// Count leading bytes in 8 byte sequence
while (src < srcend - 7){
uint64_t w64;
memcpy(&w64, src, sizeof(uint64_t));
// find leading bytes by finding every byte that is not a continuation
// byte. The bit twiddle only results in a 0 if the original byte starts
// with 0b11...
w64 = ((w64 << 1) | ~w64) & 0x8080808080808080ULL;
// compute the popcount of w64 with two bit shifts and a multiplication
size_t leads = ( (w64 >> 7) // w64 >> 7 = Sum{0<= i <= 7} x_i * 256^i (x_i \in {0,1})
* (0x0101010101010101ULL) // 0x0101010101010101 = Sum{0<= i <= 7} 256^i
// (Sum{0<= i <= 7} x_i * 256^i) * (Sum{0<= j <= 7} 256^j)
// =(mod 256^8) (Sum{0<= k <= 7} (256^k) * (Sum {0 <= l < 7} x_l)
// as the coefficients of 256^k in the result are the x_i such that i+j =(mod 8) k
// and each i satisfies this equation for exactly one such j
// So each byte of the result contains the sum we want.
) >> 56; // bit shift to get a single byte which contains Sum {0 <= j < 7} x_j
if (cnt < leads) break;
cnt-= leads;
src+= 8;
}

// Skip until next leading byte
while (src < srcend){
uint8_t w8 = *src;
if ((int8_t)w8 >= -0x40) break;
src++;
}

// Finish up with tail
while (src < srcend && cnt > 0){
uint8_t leadByte = *src++;
cnt--;
src+= (leadByte >= 0xc0) + (leadByte >= 0xe0) + (leadByte >= 0xf0);
}

return cnt == 0 ? (ssize_t)(srcend - src) : (ssize_t)(- cnt);
}

static inline const ssize_t measure_off_neon(const uint8_t *src, const uint8_t *srcend, size_t cnt)
{
while (src < srcend - 63){
uint8x16_t w128[4] = {vld1q_u8(src), vld1q_u8(src + 16), vld1q_u8(src + 32), vld1q_u8(src + 48)};
// Which bytes are either < 128 or >= 192?
uint8x16_t mask0 = vcgtq_s8((int8x16_t)w128[0], vdupq_n_s8(0xBF));
uint8x16_t mask1 = vcgtq_s8((int8x16_t)w128[1], vdupq_n_s8(0xBF));
uint8x16_t mask2 = vcgtq_s8((int8x16_t)w128[2], vdupq_n_s8(0xBF));
uint8x16_t mask3 = vcgtq_s8((int8x16_t)w128[3], vdupq_n_s8(0xBF));

uint8x16_t mask01 = vaddq_u8(mask0, mask1);
uint8x16_t mask23 = vaddq_u8(mask2, mask3);
uint8x16_t mask = vaddq_u8(mask01, mask23);

size_t leads = (size_t)(-vaddvq_s8((int8x16_t)mask));

if (cnt < leads) break;
cnt-= leads;
src+= 64;
}

while (src < srcend - 15){
uint8x16_t w128 = vld1q_u8(src);
// Which bytes are either < 128 or >= 192?
uint8x16_t mask = vcgtq_s8((int8x16_t)w128, vdupq_n_s8(0xBF));
size_t leads = (size_t)(-vaddvq_s8((int8x16_t)mask));
if (cnt < leads) break;
cnt-= leads;
src+= 16;
}

return measure_off_naive(src, srcend, cnt);
}

/*
_hs_text_measure_off takes a UTF-8 encoded buffer, specified by (src, off, len),
and a number of code points (aka characters) cnt. If the buffer is long enough
to contain cnt characters, then _hs_text_measure_off returns a non-negative number,
measuring their size in code units (aka bytes). If the buffer is shorter,
_hs_text_measure_off returns a non-positive number, which is a negated total count
of characters available in the buffer. If len = 0 or cnt = 0, this function returns 0
as well.
This scheme allows us to implement both take/drop and length with the same C function.
The input buffer (src, off, len) must be a valid UTF-8 sequence,
this condition is not checked.
*/
ssize_t _hs_text_measure_off(const uint8_t *src, size_t off, size_t len, size_t cnt) {
ssize_t ret = measure_off_neon(src + off, src + off + len, cnt);
return ret >= 0 ? ((ssize_t)len - ret) : (- (cnt + ret));
}
5 changes: 4 additions & 1 deletion text.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,12 @@ library
cpp-options: -DPURE_HASKELL
else
c-sources: cbits/is_ascii.c
cbits/measure_off.c
cbits/reverse.c
cbits/utils.c
if (arch(aarch64))
c-sources: cbits/aarch64/measure_off.c
else
c-sources: cbits/measure_off.c

hs-source-dirs: src

Expand Down

0 comments on commit e4e110e

Please sign in to comment.