Skip to content

Commit

Permalink
tweaking
Browse files Browse the repository at this point in the history
  • Loading branch information
lemire committed Aug 21, 2024
1 parent 9c0d239 commit fde4ad0
Showing 1 changed file with 129 additions and 81 deletions.
210 changes: 129 additions & 81 deletions src/Base64ARMUTF8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,94 +60,143 @@ private unsafe static void LoadBlock(Block64* b, char* src)
var m8 = AdvSimd.LoadVector128((ushort*)(src + 56));
// Pack 16-bit chars down to 8-bit chars, handling two vectors at a time
b->chunk0 = AdvSimd.ExtractNarrowingSaturateUpper(AdvSimd.ExtractNarrowingSaturateLower(m1.AsInt16()), m2.AsInt16()).AsByte();
b->chunk1 = AdvSimd.xtractNarrowingSaturateUpper(AdvSimd.ExtractNarrowingSaturateLower(m3.AsInt16()), m4.AsInt16()).AsByte();
b->chunk2 = AdvSimd.xtractNarrowingSaturateUpper(AdvSimd.ExtractNarrowingSaturateLower(m5.AsInt16()), m6.AsInt16()).AsByte();
b->chunk3 = AdvSimd.PxtractNarrowingSaturateUpper(AdvSimd.ExtractNarrowingSaturateLower(m7.AsInt16()), m8.AsInt16()).AsByte();
b->chunk1 = AdvSimd.ExtractNarrowingSaturateUpper(AdvSimd.ExtractNarrowingSaturateLower(m3.AsInt16()), m4.AsInt16()).AsByte();
b->chunk2 = AdvSimd.ExtractNarrowingSaturateUpper(AdvSimd.ExtractNarrowingSaturateLower(m5.AsInt16()), m6.AsInt16()).AsByte();
b->chunk3 = AdvSimd.ExtractNarrowingSaturateUpper(AdvSimd.ExtractNarrowingSaturateLower(m7.AsInt16()), m8.AsInt16()).AsByte();
}

/// <summary>
/// Classifies the 64 bytes held in <paramref name="b"/> with two nibble-indexed
/// lookup tables and translates them, in place, from base64 characters to their
/// 6-bit values.
/// </summary>
/// <param name="base64Url">
/// When true the base64url tables are used ('-' and '_' map to 62 and 63);
/// otherwise the standard tables are used ('+' and '/').
/// </param>
/// <param name="b">
/// Block of four 16-byte chunks. Every chunk is overwritten with the translated
/// bytes; lanes reported in the returned mask hold meaningless values.
/// </param>
/// <param name="error">
/// Assigned (not OR-ed): true when at least one byte has a classification value
/// above 3. Space, tab, LF, FF and CR classify as 1 or 2, so they are reported
/// in the returned mask without raising <paramref name="error"/>.
/// </param>
/// <returns>
/// A 64-bit mask in which bit i is set when byte i of the block (chunk0 occupies
/// the lowest 16 bits) is not a base64 character of the selected alphabet.
/// </returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe ulong ToBase64Mask(bool base64Url, Block64* b, ref bool error)
{
    // Vector of 0xf for masking lower nibbles
    Vector128<byte> v0f = Vector128.Create((byte)0xf);

    // 0xFF in every lane that holds '_' (0x5f); stays all-zero for standard base64.
    Vector128<byte> underscore0 = Vector128<byte>.Zero;
    Vector128<byte> underscore1 = Vector128<byte>.Zero;
    Vector128<byte> underscore2 = Vector128<byte>.Zero;
    Vector128<byte> underscore3 = Vector128<byte>.Zero;

    if (base64Url)
    {
        Vector128<byte> underscore = Vector128.Create((byte)0x5f);
        underscore0 = Vector128.Equals(b->chunk0, underscore);
        underscore1 = Vector128.Equals(b->chunk1, underscore);
        underscore2 = Vector128.Equals(b->chunk2, underscore);
        underscore3 = Vector128.Equals(b->chunk3, underscore);
    }

    // Extract lower nibbles
    Vector128<byte> loNibbles0 = b->chunk0 & v0f;
    Vector128<byte> loNibbles1 = b->chunk1 & v0f;
    Vector128<byte> loNibbles2 = b->chunk2 & v0f;
    Vector128<byte> loNibbles3 = b->chunk3 & v0f;

    // Extract higher nibbles
    Vector128<byte> hiNibbles0 = AdvSimd.ShiftRightLogical(b->chunk0, 4);
    Vector128<byte> hiNibbles1 = AdvSimd.ShiftRightLogical(b->chunk1, 4);
    Vector128<byte> hiNibbles2 = AdvSimd.ShiftRightLogical(b->chunk2, 4);
    Vector128<byte> hiNibbles3 = AdvSimd.ShiftRightLogical(b->chunk3, 4);

    // Classification tables: a byte is a base64 character exactly when
    // lutLo[low nibble] & lutHi[high nibble] == 0.
    Vector128<byte> lutLo = base64Url
        ? Vector128.Create((byte)0x3A, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70,
                           0x70, 0x61, 0xE1, 0xF4, 0xE5, 0xA5, 0xF4, 0xF4)
        : Vector128.Create((byte)0x3A, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70,
                           0x70, 0x61, 0xE1, 0xB4, 0xE5, 0xE5, 0xF4, 0xB4);

    // The high-nibble table is the same for both alphabets.
    Vector128<byte> lutHi = Vector128.Create((byte)0x11, 0x20, 0x42, 0x80, 0x8, 0x4, 0x8, 0x4,
                                             0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20);

    // Lookup for lower and higher nibbles
    Vector128<byte> lo0 = AdvSimd.Arm64.VectorTableLookup(lutLo, loNibbles0);
    Vector128<byte> hi0 = AdvSimd.Arm64.VectorTableLookup(lutHi, hiNibbles0);
    Vector128<byte> lo1 = AdvSimd.Arm64.VectorTableLookup(lutLo, loNibbles1);
    Vector128<byte> hi1 = AdvSimd.Arm64.VectorTableLookup(lutHi, hiNibbles1);
    Vector128<byte> lo2 = AdvSimd.Arm64.VectorTableLookup(lutLo, loNibbles2);
    Vector128<byte> hi2 = AdvSimd.Arm64.VectorTableLookup(lutHi, hiNibbles2);
    Vector128<byte> lo3 = AdvSimd.Arm64.VectorTableLookup(lutLo, loNibbles3);
    Vector128<byte> hi3 = AdvSimd.Arm64.VectorTableLookup(lutHi, hiNibbles3);

    if (base64Url)
    {
        // '_' shares its nibbles with rejected bytes, so its lanes are cleared
        // explicitly to make it classify as valid.
        hi0 = AdvSimd.BitwiseClear(hi0, underscore0);
        hi1 = AdvSimd.BitwiseClear(hi1, underscore1);
        hi2 = AdvSimd.BitwiseClear(hi2, underscore2);
        hi3 = AdvSimd.BitwiseClear(hi3, underscore3);
    }

    // Check for invalid characters. The classification of a byte is lo & hi;
    // lutHi alone is never below 4, so it cannot be used for this test.
    Vector128<byte> classification = (lo0 & hi0) | (lo1 & hi1) | (lo2 & hi2) | (lo3 & hi3);
    byte checks = AdvSimd.Arm64.MaxAcross(classification).ToScalar();

    error = checks > 0x3;

    ulong badCharMask = 0;
    // The mask is needed as soon as any byte is flagged, including whitespace
    // that does not raise the error flag.
    if (checks != 0)
    {
        Vector128<byte> test0 = AdvSimd.CompareTest(lo0, hi0);
        Vector128<byte> test1 = AdvSimd.CompareTest(lo1, hi1);
        Vector128<byte> test2 = AdvSimd.CompareTest(lo2, hi2);
        Vector128<byte> test3 = AdvSimd.CompareTest(lo3, hi3);
        Vector128<byte> bitMask = Vector128.Create((byte)0x01, 0x02, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
                                                   0x01, 0x02, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80);
        // Three rounds of pairwise addition fold the 64 one-bit lanes into the
        // low 8 bytes of the vector, two bytes per chunk.
        Vector128<byte> sum0 = AdvSimd.Arm64.AddPairwise(test0 & bitMask, test1 & bitMask);
        Vector128<byte> sum1 = AdvSimd.Arm64.AddPairwise(test2 & bitMask, test3 & bitMask);
        sum0 = AdvSimd.Arm64.AddPairwise(sum0, sum1);
        sum0 = AdvSimd.Arm64.AddPairwise(sum0, sum0);
        badCharMask = sum0.AsUInt64().ToScalar();
    }

    // Offsets added (modulo 256) to each character to obtain its 6-bit value,
    // indexed by the high nibble. Slot 0 is reached by '_' (high nibble cleared)
    // and slot 1 by '/' or '-' (the comparison result 0xFF acts as -1).
    Vector128<byte> rollLut = base64Url
        ? Vector128.Create((byte)0xe0, 0x11, 0x13, 0x4, 0xbf, 0xbf, 0xb9, 0xb9,
                           0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0)
        : Vector128.Create((byte)0x0, 0x10, 0x13, 0x4, 0xbf, 0xbf, 0xb9, 0xb9,
                           0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0);
    Vector128<byte> secondLast = base64Url
        ? Vector128.Create((byte)0x2d)
        : Vector128.Create((byte)0x2f);

    if (base64Url)
    {
        hiNibbles0 = AdvSimd.BitwiseClear(hiNibbles0, underscore0);
        hiNibbles1 = AdvSimd.BitwiseClear(hiNibbles1, underscore1);
        hiNibbles2 = AdvSimd.BitwiseClear(hiNibbles2, underscore2);
        hiNibbles3 = AdvSimd.BitwiseClear(hiNibbles3, underscore3);
    }

    Vector128<byte> roll0 = AdvSimd.Arm64.VectorTableLookup(rollLut, AdvSimd.CompareEqual(b->chunk0, secondLast) + hiNibbles0);
    Vector128<byte> roll1 = AdvSimd.Arm64.VectorTableLookup(rollLut, AdvSimd.CompareEqual(b->chunk1, secondLast) + hiNibbles1);
    Vector128<byte> roll2 = AdvSimd.Arm64.VectorTableLookup(rollLut, AdvSimd.CompareEqual(b->chunk2, secondLast) + hiNibbles2);
    Vector128<byte> roll3 = AdvSimd.Arm64.VectorTableLookup(rollLut, AdvSimd.CompareEqual(b->chunk3, secondLast) + hiNibbles3);

    b->chunk0 += roll0;
    b->chunk1 += roll1;
    b->chunk2 += roll2;
    b->chunk3 += roll3;

    return badCharMask;
}

/// <summary>
/// Passes each of the four chunks of <paramref name="b"/>, together with its
/// 16-bit slice of <paramref name="mask"/>, to <c>Compress</c>.
/// </summary>
/// <param name="b">Block whose chunks are compressed.</param>
/// <param name="mask">
/// 64-bit mask, 16 bits per chunk with chunk0 in the lowest bits. The
/// destination of each chunk is advanced by the number of zero bits in the
/// slices of the chunks before it.
/// </param>
/// <param name="output">
/// Destination buffer. The visible stores in <c>Compress</c> write a full
/// 16-byte vector, so the buffer presumably needs slack past the returned
/// count — confirm against the caller.
/// </param>
/// <returns>The number of zero bits in <paramref name="mask"/>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe static ulong CompressBlock(ref Block64 b, ulong mask, byte* output)
{
    // Zero bits of mask become one bits here, so the popcounts below are the
    // running totals by which the destination is advanced.
    ulong nmask = ~mask;

    Compress(b.chunk0, (ushort)mask, output);
    Compress(b.chunk1, (ushort)(mask >> 16), output + UInt64.PopCount(nmask & 0xFFFF));
    Compress(b.chunk2, (ushort)(mask >> 32), output + UInt64.PopCount(nmask & 0xFFFFFFFF));
    Compress(b.chunk3, (ushort)(mask >> 48), output + UInt64.PopCount(nmask & 0xFFFFFFFFFFFFUL));

    return UInt64.PopCount(nmask);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void Compress(Vector128<byte> data, ushort mask, byte* output)
{
if (mask == 0)
{
Sse2.Store(output, data);
Vector128.Store(data, output);
return;
}

Expand All @@ -165,7 +214,7 @@ private static unsafe void Compress(Vector128<byte> data, ushort mask, byte* out
Vector128<sbyte> shufmask = Vector128.Create(value2, value1).AsSByte();

// Increment by 0x08 the second half of the mask
shufmask = Sse2.Add(shufmask, Vector128.Create(0x08080808, 0x08080808, 0, 0).AsSByte());
shufmask = shufmask + Vector128.Create(0x08080808, 0x08080808, 0, 0).AsSByte();

// this is the version "nearly pruned"
Vector128<sbyte> pruned = Ssse3.Shuffle(data.AsSByte(), shufmask);
Expand All @@ -179,21 +228,20 @@ private static unsafe void Compress(Vector128<byte> data, ushort mask, byte* out

fixed (byte* tablePtr = Tables.pshufbCombineTable)
{
Vector128<byte> compactmask = Sse2.LoadVector128(tablePtr + pop1 * 8);
Vector128<byte> compactmask = Vector128.Load(tablePtr + pop1 * 8);

Vector128<byte> answer = Ssse3.Shuffle(pruned.AsByte(), compactmask);
Sse2.Store(output, answer);
Vector128.Store(answer, output);
}
}

/// <summary>
/// Copies the four 16-byte chunks of <paramref name="b"/> to
/// <paramref name="output"/> back to back.
/// </summary>
/// <param name="b">Source block.</param>
/// <param name="output">Destination; 64 bytes are written starting here.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void CopyBlock(Block64* b, byte* output)
{
    // Directly store each 128-bit chunk to the output buffer using the
    // platform-neutral Vector128 API.
    Vector128.Store(b->chunk0, output);
    Vector128.Store(b->chunk1, output + 16);
    Vector128.Store(b->chunk2, output + 32);
    Vector128.Store(b->chunk3, output + 48);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down Expand Up @@ -231,16 +279,16 @@ private unsafe static void Base64Decode(byte* output, Vector128<byte> input)
Vector128<byte> t2 = Ssse3.Shuffle(t1.AsSByte(), packShuffle).AsByte();

// Store the output. This writes 16 bytes, but we only need 12.
Sse2.Store(output, t2);
Vector128.Store(t2, output);
}

/// <summary>
/// Decodes 64 bytes at <paramref name="srcPtr"/> as four 16-byte vectors, each
/// handed to <c>Base64Decode</c> with its destination advanced by 12 bytes.
/// </summary>
/// <param name="outPtr">
/// Destination. <c>Base64Decode</c> stores a full 16-byte vector per call although
/// only 12 bytes are needed, so the last call writes up to <c>outPtr + 52</c>.
/// </param>
/// <param name="srcPtr">Source; 64 bytes are read starting here.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void Base64DecodeBlock(byte* outPtr, byte* srcPtr)
{
    Base64Decode(outPtr, Vector128.Load(srcPtr));
    Base64Decode(outPtr + 12, Vector128.Load(srcPtr + 16));
    Base64Decode(outPtr + 24, Vector128.Load(srcPtr + 32));
    Base64Decode(outPtr + 36, Vector128.Load(srcPtr + 48));
}

// Function to decode a Base64 block into binary data.
Expand All @@ -256,10 +304,10 @@ private static unsafe void Base64DecodeBlock(byte* output, Block64* block)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void Base64DecodeBlockSafe(byte* outPtr, byte* srcPtr)
{
Base64Decode(outPtr, Sse2.LoadVector128(srcPtr));
Base64Decode(outPtr + 12, Sse2.LoadVector128(srcPtr + 16));
Base64Decode(outPtr + 24, Sse2.LoadVector128(srcPtr + 32));
Vector128<byte> tempBlock = Sse2.LoadVector128(srcPtr + 48);
Base64Decode(outPtr, Vector128.Load(srcPtr));
Base64Decode(outPtr + 12, Vector128.Load(srcPtr + 16));
Base64Decode(outPtr + 24, Vector128.Load(srcPtr + 32));
Vector128<byte> tempBlock = Vector128.Load(srcPtr + 48);
byte[] buffer = new byte[16];
fixed (byte* bufferPtr = buffer)
{
Expand All @@ -271,20 +319,20 @@ private static unsafe void Base64DecodeBlockSafe(byte* outPtr, byte* srcPtr)
}
}

/// <summary>
/// Entry point of the ARM base64 decoder: forwards to the base64url or the
/// standard implementation depending on <paramref name="isUrl"/>.
/// </summary>
/// <param name="source">Input bytes, forwarded unchanged.</param>
/// <param name="dest">Output buffer, forwarded unchanged.</param>
/// <param name="bytesConsumed">Set by the selected implementation.</param>
/// <param name="bytesWritten">Set by the selected implementation.</param>
/// <param name="isUrl">True selects the base64url implementation; defaults to false.</param>
/// <returns>The status returned by the selected implementation.</returns>
// Caller is responsible for checking that (AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian)
public unsafe static OperationStatus DecodeFromBase64ARM(ReadOnlySpan<byte> source, Span<byte> dest, out int bytesConsumed, out int bytesWritten, bool isUrl = false)
{
    if (isUrl)
    {
        return InnerDecodeFromBase64ARMUrl(source, dest, out bytesConsumed, out bytesWritten);
    }
    else
    {
        return InnerDecodeFromBase64ARMRegular(source, dest, out bytesConsumed, out bytesWritten);
    }
}

private unsafe static OperationStatus InnerDecodeFromBase64SSERegular(ReadOnlySpan<byte> source, Span<byte> dest, out int bytesConsumed, out int bytesWritten)
private unsafe static OperationStatus InnerDecodeFromBase64ARMRegular(ReadOnlySpan<byte> source, Span<byte> dest, out int bytesConsumed, out int bytesWritten)
{
// translation from ASCII to 6 bit values
bool isUrl = false;
Expand Down Expand Up @@ -624,7 +672,7 @@ private unsafe static OperationStatus InnerDecodeFromBase64SSERegular(ReadOnlySp
}
}

private unsafe static OperationStatus InnerDecodeFromBase64SSEUrl(ReadOnlySpan<byte> source, Span<byte> dest, out int bytesConsumed, out int bytesWritten)
private unsafe static OperationStatus InnerDecodeFromBase64ARMUrl(ReadOnlySpan<byte> source, Span<byte> dest, out int bytesConsumed, out int bytesWritten)
{
// translation from ASCII to 6 bit values
bool isUrl = true;
Expand Down

0 comments on commit fde4ad0

Please sign in to comment.