Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Nick-Nuon committed Sep 10, 2024
1 parent d2c641f commit 7999837
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 31 deletions.
2 changes: 1 addition & 1 deletion benchmark/Benchmark.cs
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ public unsafe void AVX2DecodingRealDataWithAllocUTF8()
RunAVX2DecodingBenchmarkWithAllocUTF8(FileContent, DecodedLengths);
}

[Benchmark]
[Benchmark]
[BenchmarkCategory("AVX512")]
public unsafe void AVX512DecodingRealDataUTF8()
{
Expand Down
37 changes: 7 additions & 30 deletions src/Base64AVX512UTF8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ namespace AVX512
{
public static partial class Base64
{

// If needed for debugging, you can do the following:
/*
static string VectorToString(Vector512<byte> vector)
Expand Down Expand Up @@ -95,7 +94,6 @@ private static unsafe UInt64 ToBase64Mask(bool base64Url, Block64* b, ref bool e
-128, -128, -128, -128, -128, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -128).AsByte();

//Vector512<byte> translated = Vector512.Permutex2var(lookup0, input, lookup1);
Vector512<byte> translated = Avx512Vbmi.PermuteVar64x8x2(lookup0, input, lookup1);
Vector512<byte> combined = Avx512F.Or(translated.AsInt64(), input.AsInt64()).AsByte();
UInt64 mask = combined.ExtractMostSignificantBits();
Expand All @@ -119,22 +117,22 @@ private static unsafe UInt64 ToBase64Mask(bool base64Url, Block64* b, ref bool e
private unsafe static ulong CompressBlock(ref Block64 b, ulong mask, byte* output, byte* tablePtr)
{
// At the time of writing .NET 9.0 does not seem to expose _mm512_maskz_compress_epi8
// directly, see this discussion:https://github.com/dotnet/runtime/discussions/100829
// directly
ulong nmask = ~mask;
var part0 = Avx512F.ExtractVector128(b.chunk0.AsByte(), 0);
var part1 = Avx512F.ExtractVector128(b.chunk0.AsByte(), 1);
var part2 = Avx512F.ExtractVector128(b.chunk0.AsByte(), 2);
var part3 = Avx512F.ExtractVector128(b.chunk0.AsByte(), 3);

Compress(part0, (ushort)mask, output, tablePtr);
Compress(part1, (ushort)(mask >> 16), output + Popcnt.X64.PopCount(nmask & 0xFFFF), tablePtr);// DEBUG: ushort vs uint32?
Compress(part1, (ushort)(mask >> 16), output + Popcnt.X64.PopCount(nmask & 0xFFFF), tablePtr);
Compress(part2, (ushort)(mask >> 32), output + Popcnt.X64.PopCount(nmask & 0xFFFFFFFF), tablePtr);
Compress(part3, (ushort)(mask >> 48), output + Popcnt.X64.PopCount(nmask & 0xFFFFFFFFFFFFUL), tablePtr);

return Popcnt.X64.PopCount(nmask);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)] // This Compress is the same as in SSE
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void Compress(Vector128<byte> data, ushort mask, byte* output, byte* tablePtr)
{
if (mask == 0)
Expand Down Expand Up @@ -173,22 +171,6 @@ private static unsafe void Compress(Vector128<byte> data, ushort mask, byte* out
Vector128<byte> answer = Ssse3.Shuffle(pruned.AsByte(), compactmask);
Sse2.Store(output, answer);
}

public static unsafe void Compress(Vector256<byte> data, uint mask, byte* output, byte* tablePtr)
{
if (mask == 0)
{
Avx2.Store(output, data);
return;
}

// Perform compression on the lower 128 bits
Compress(data.GetLower().AsByte(), (ushort)mask, output, tablePtr);

// Perform compression on the upper 128 bits, shifting output pointer by the number of 1's in the lower 16 bits of mask
int popCount = (int)Popcnt.PopCount(~mask & 0xFFFF);
Compress(Avx2.ExtractVector128(data.AsByte(), 1), (ushort)(mask >> 16), output + popCount, tablePtr);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void CopyBlock(Block64* b, byte* output)
Expand All @@ -202,6 +184,7 @@ private static unsafe void Base64Decode(byte* output, Vector512<byte> input)
{
// Perform multiply and add operations using AVX-512 instructions
Vector512<short> mergeAbAndBc = Avx512Vbmi.MultiplyAddAdjacent(input, Vector512.Create(0x01400140).AsSByte());
// Vector512<short> mergeAbAndBc = Avx512Vbmi.MultiplyAddAdjacent(input.AsInt16(), Vector512.Create(0x01400140).AsInt16());
Vector512<int> merged = Avx512BW.MultiplyAddAdjacent(mergeAbAndBc.AsInt16(), Vector512.Create(0x00011000).AsInt16());

// Define the shuffle pattern
Expand All @@ -215,6 +198,8 @@ private static unsafe void Base64Decode(byte* output, Vector512<byte> input)
Vector512<byte> shuffled = Avx512Vbmi.Shuffle(pack, merged.AsByte());

// Store the result back in the output (48 bytes)
// _mm512_mask_storeu_epi64 does not seem to be exposed yet
// See https://github.com/dotnet/runtime/discussions/100829
Avx512F.Store(output, shuffled); // Assuming 48 bytes are being written
}

Expand All @@ -230,7 +215,6 @@ private static unsafe void Base64DecodeBlock(byte* output, Block64* block)
{
Base64Decode(output, block->chunk0);
}

// Caller is responsible for checking that Avx2.IsSupported && Popcnt.IsSupported
public unsafe static OperationStatus DecodeFromBase64AVX512(ReadOnlySpan<byte> source, Span<byte> dest, out int bytesConsumed, out int bytesWritten, bool isUrl = false)
{
Expand Down Expand Up @@ -304,7 +288,6 @@ private unsafe static OperationStatus InnerDecodeFromBase64AVX512Regular(ReadOnl
byte* srcEnd64 = srcInit + bytesToProcess - 64;
while (src <= srcEnd64)
{

Base64.Block64 b;
Base64.LoadBlock(&b, src);
src += 64;
Expand Down Expand Up @@ -348,7 +331,6 @@ private unsafe static OperationStatus InnerDecodeFromBase64AVX512Regular(ReadOnl
else
{
Base64DecodeBlock(dst, &b);

bufferBytesWritten += 48;
dst += 48;
}
Expand All @@ -361,7 +343,6 @@ private unsafe static OperationStatus InnerDecodeFromBase64AVX512Regular(ReadOnl
bufferBytesWritten += 48;
dst += 48;
}

Buffer.MemoryCopy(startOfBuffer + (blocksSize - 1) * 64, startOfBuffer, 64, 64);
bufferPtr -= (blocksSize - 1) * 64;

Expand All @@ -373,12 +354,11 @@ private unsafe static OperationStatus InnerDecodeFromBase64AVX512Regular(ReadOnl
}
// Optimization note: if this is almost full, then it is worth our
// time, otherwise, we should just decode directly.

int lastBlock = (int)((bufferPtr - startOfBuffer) % 64);
int lastBlockSrcCount = 0;
// There is at some bytes remaining beyond the last 64 bit block remaining
if (lastBlock != 0 && srcEnd - src + lastBlock >= 64) // We first check if there is any error and eliminate white spaces?:
{
int lastBlockSrcCount = 0;
while ((bufferPtr - startOfBuffer) % 64 != 0 && src < srcEnd)
{
byte val = SimdBase64.Tables.GetToBase64Value((uint)*src);
Expand All @@ -402,7 +382,6 @@ private unsafe static OperationStatus InnerDecodeFromBase64AVX512Regular(ReadOnl
src++;
lastBlockSrcCount++;
}

}

byte* subBufferPtr = startOfBuffer;
Expand All @@ -414,7 +393,6 @@ private unsafe static OperationStatus InnerDecodeFromBase64AVX512Regular(ReadOnl
{
while (subBufferPtr + 4 < bufferPtr) // we decode one base64 element (4 bit) at a time
{

UInt32 triple = (((UInt32)((byte)(subBufferPtr[0])) << 3 * 6) +
((UInt32)((byte)(subBufferPtr[1])) << 2 * 6) +
((UInt32)((byte)(subBufferPtr[2])) << 1 * 6) +
Expand Down Expand Up @@ -456,7 +434,6 @@ private unsafe static OperationStatus InnerDecodeFromBase64AVX512Regular(ReadOnl

if (leftover == 1)
{

bytesConsumed = (int)(src - srcInit);
bytesWritten = (int)(dst - dstInit);
return OperationStatus.NeedMoreData;
Expand Down

0 comments on commit 7999837

Please sign in to comment.