Skip to content

Commit

Permalink
ShuffleI8 improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
macaba committed Jan 15, 2025
1 parent 9000c07 commit 1bb3084
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 86 deletions.
42 changes: 0 additions & 42 deletions source/TS.NET.Tests/ShuffleI8Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,48 +5,6 @@ namespace TS.NET.Tests
{
public class ShuffleI8Tests
{
[Fact]
public void ShuffleI8_FourChannels_Samples64()
{
const int length = 64;
ReadOnlySpan<sbyte> input = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4];
Span<sbyte> output = new sbyte[length];

ShuffleI8.FourChannels(input, output);

Span<sbyte> expectedOutput = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4];

for (int i = 0; i < length; i++)
{
Assert.Equal(expectedOutput[i], output[i]);
}
}

[Fact]
public void ShuffleI8_FourChannels_Samples64_Alt()
{
const int length = 64;
Span<sbyte> input = new sbyte[length];
Span<sbyte> output = new sbyte[length];

int n = 0;
for(int i = 0; i < length; i+=4)
{
input[i] = (sbyte)n;
input[i + 1] = (sbyte)(n + 16);
input[i + 2] = (sbyte)(n + 32);
input[i + 3] = (sbyte)(n + 48);
n++;
}

ShuffleI8.FourChannels(input, output);

for (int i = 0; i < length; i++)
{
Assert.Equal(i, output[i]);
}
}

[Fact]
public void ShuffleI8_FourChannels_Samples128()
{
Expand Down
97 changes: 53 additions & 44 deletions source/TS.NET/Processing/ShuffleI8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ public static void FourChannels(ReadOnlySpan<sbyte> input, Span<sbyte> output)

if (Avx2.IsSupported) // Const after JIT/AOT
{
var processingLength = Vector256<sbyte>.Count * 2; // 64
var processingLength = Vector256<sbyte>.Count * 4; // 128
if (input.Length % processingLength != 0) throw new ArgumentException($"Input length must be multiple of {processingLength}");

int ch2Offset_64 = channelBlockSizeBytes / 8;
int ch3Offset_64 = (channelBlockSizeBytes * 2) / 8;
int ch4Offset_64 = (channelBlockSizeBytes * 3) / 8;
int ch2Offset = channelBlockSizeBytes;
int ch3Offset = channelBlockSizeBytes * 2;
int ch4Offset = channelBlockSizeBytes * 3;
Vector256<sbyte> shuffleMask = Vector256.Create(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15).AsSByte();
Vector256<int> permuteMask = Vector256.Create(0, 4, 1, 5, 2, 6, 3, 7);
unsafe
Expand All @@ -29,29 +29,36 @@ public static void FourChannels(ReadOnlySpan<sbyte> input, Span<sbyte> output)
fixed (sbyte* outputP = output)
{
sbyte* inputPtr = inputP;
ulong* outputPtr_64 = (ulong*)outputP;
sbyte* outputPtr = outputP;
sbyte* finishPtr = inputP + input.Length;
while (inputPtr < finishPtr)
{
// Note: x2 unroll seems to be the sweet spot in benchmarks, allowing for 128 bit stores
var loaded1 = Avx.LoadVector256(inputPtr);
var loaded2 = Avx.LoadVector256(inputPtr + Vector256<sbyte>.Count);
var shuffled1 = Avx2.Shuffle(loaded1, shuffleMask); // shuffled1 = <1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4>
var loaded1 = Vector256.Load(inputPtr);
var loaded2 = Vector256.Load(inputPtr + Vector256<sbyte>.Count);
var loaded3 = Vector256.Load(inputPtr + Vector256<sbyte>.Count * 2);
var loaded4 = Vector256.Load(inputPtr + Vector256<sbyte>.Count * 3);
var shuffled1 = Avx2.Shuffle(loaded1, shuffleMask);
var shuffled2 = Avx2.Shuffle(loaded2, shuffleMask);
var permuted1 = Avx2.PermuteVar8x32(shuffled1.AsInt32(), permuteMask); // permuted1 = <1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4>
var permuted2 = Avx2.PermuteVar8x32(shuffled2.AsInt32(), permuteMask);
var permuted1_64 = permuted1.AsUInt64();
var permuted2_64 = permuted2.AsUInt64();
var unpackHigh = Avx2.UnpackHigh(permuted1_64, permuted2_64); // unpackHigh = <2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4>
var unpackLow = Avx2.UnpackLow(permuted1_64, permuted2_64); // unpackLow = <1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3>

Vector128.Store(unpackLow.GetLower(), outputPtr_64);
Vector128.Store(unpackHigh.GetLower(), outputPtr_64 + ch2Offset_64);
Vector128.Store(unpackLow.GetUpper(), outputPtr_64 + ch3Offset_64);
Vector128.Store(unpackHigh.GetUpper(), outputPtr_64 + ch4Offset_64);

var shuffled3 = Avx2.Shuffle(loaded3, shuffleMask);
var shuffled4 = Avx2.Shuffle(loaded4, shuffleMask);
var permuted1 = Avx2.PermuteVar8x32(shuffled1.AsInt32(), permuteMask).AsUInt64();
var permuted2 = Avx2.PermuteVar8x32(shuffled2.AsInt32(), permuteMask).AsUInt64();
var permuted3 = Avx2.PermuteVar8x32(shuffled3.AsInt32(), permuteMask).AsUInt64();
var permuted4 = Avx2.PermuteVar8x32(shuffled4.AsInt32(), permuteMask).AsUInt64();
var unpackLow = Avx2.UnpackLow(permuted1, permuted2);
var unpackLow2 = Avx2.UnpackLow(permuted3, permuted4);
var channel1 = Avx2.Permute2x128(unpackLow, unpackLow2, 0x20).AsSByte();
var channel3 = Avx2.Permute2x128(unpackLow, unpackLow2, 0x31).AsSByte();
var unpackHigh = Avx2.UnpackHigh(permuted1, permuted2);
var unpackHigh2 = Avx2.UnpackHigh(permuted3, permuted4);
var channel2 = Avx2.Permute2x128(unpackHigh, unpackHigh2, 0x20).AsSByte();
var channel4 = Avx2.Permute2x128(unpackHigh, unpackHigh2, 0x31).AsSByte();
Vector256.Store(channel1, outputPtr);
Vector256.Store(channel2, outputPtr + ch2Offset);
Vector256.Store(channel3, outputPtr + ch3Offset);
Vector256.Store(channel4, outputPtr + ch4Offset);
inputPtr += processingLength;
outputPtr_64 += 2;
outputPtr += Vector256<sbyte>.Count;
}
}
}
Expand All @@ -75,10 +82,10 @@ public static void FourChannels(ReadOnlySpan<sbyte> input, Span<sbyte> output)
sbyte* finishPtr = inputP + input.Length;
while (inputPtr < finishPtr)
{
var loaded1 = Sse2.LoadVector128(inputPtr);
var loaded2 = Sse2.LoadVector128(inputPtr + Vector128<sbyte>.Count);
var loaded3 = Sse2.LoadVector128(inputPtr + Vector128<sbyte>.Count * 2);
var loaded4 = Sse2.LoadVector128(inputPtr + Vector128<sbyte>.Count * 3);
var loaded1 = Vector128.Load(inputPtr);
var loaded2 = Vector128.Load(inputPtr + Vector128<sbyte>.Count);
var loaded3 = Vector128.Load(inputPtr + Vector128<sbyte>.Count * 2);
var loaded4 = Vector128.Load(inputPtr + Vector128<sbyte>.Count * 3);
var shuffled1 = Ssse3.Shuffle(loaded1, shuffleMask);
var shuffled2 = Ssse3.Shuffle(loaded2, shuffleMask);
var shuffled3 = Ssse3.Shuffle(loaded3, shuffleMask);
Expand All @@ -96,7 +103,7 @@ public static void FourChannels(ReadOnlySpan<sbyte> input, Span<sbyte> output)
Vector128.Store(channel3, outputPtr + ch3Offset);
Vector128.Store(channel4, outputPtr + ch4Offset);
inputPtr += processingLength;
outputPtr += 16;
outputPtr += Vector128<sbyte>.Count;
}
}
}
Expand Down Expand Up @@ -130,9 +137,8 @@ public static void FourChannels(ReadOnlySpan<sbyte> input, Span<sbyte> output)
AdvSimd.Store(outputPtr + ch2Offset, loaded.Value2);
AdvSimd.Store(outputPtr + ch3Offset, loaded.Value3);
AdvSimd.Store(outputPtr + ch4Offset, loaded.Value4);

inputPtr += processingLength;
outputPtr += 16;
outputPtr += Vector128<sbyte>.Count;
}
}
}
Expand Down Expand Up @@ -177,11 +183,11 @@ public static void TwoChannels(ReadOnlySpan<sbyte> input, Span<sbyte> output)

if (Avx2.IsSupported) // Const after JIT/AOT
{
var processingLength = Vector256<sbyte>.Count; // 32
var processingLength = Vector256<sbyte>.Count * 2; // 64
if (input.Length % processingLength != 0)
throw new ArgumentException($"Input length must be multiple of {processingLength}");

int ch2Offset_64 = channelBlockSizeBytes / 8;
int ch2Offset = channelBlockSizeBytes;
Vector256<sbyte> shuffleMask = Vector256.Create(0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15, 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15).AsSByte();
Vector256<int> permuteMask = Vector256.Create(0, 1, 4, 5, 2, 3, 6, 7);
unsafe
Expand All @@ -190,19 +196,22 @@ public static void TwoChannels(ReadOnlySpan<sbyte> input, Span<sbyte> output)
fixed (sbyte* outputP = output)
{
sbyte* inputPtr = inputP;
ulong* outputPtr_64 = (ulong*)outputP;
sbyte* outputPtr = outputP;
sbyte* finishPtr = inputP + input.Length;
while (inputPtr < finishPtr)
{
var shuffled1 = Avx2.Shuffle(Avx.LoadVector256(inputPtr), shuffleMask); // shuffled1 = <1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2>
var permuted1 = Avx2.PermuteVar8x32(shuffled1.AsInt32(), permuteMask); // permuted1 = <16843009, 16843009, 16843009, 16843009, 33686018, 33686018, 33686018, 33686018>
var permuted1_64 = permuted1.AsUInt64();
outputPtr_64[0] = permuted1_64[0];
outputPtr_64[1] = permuted1_64[1];
outputPtr_64[0 + ch2Offset_64] = permuted1_64[2];
outputPtr_64[1 + ch2Offset_64] = permuted1_64[3];
var loaded1 = Vector256.Load(inputPtr);
var loaded2 = Vector256.Load(inputPtr + Vector256<sbyte>.Count);
var shuffled1 = Avx2.Shuffle(loaded1, shuffleMask);
var shuffled2 = Avx2.Shuffle(loaded2, shuffleMask);
var permuted1 = Avx2.PermuteVar8x32(shuffled1.AsInt32(), permuteMask);
var permuted2 = Avx2.PermuteVar8x32(shuffled2.AsInt32(), permuteMask);
var channel1 = Avx2.Permute2x128(permuted1, permuted2, 0x20).AsSByte();
var channel2 = Avx2.Permute2x128(permuted1, permuted2, 0x31).AsSByte();
Vector256.Store(channel1, outputPtr);
Vector256.Store(channel2, outputPtr + ch2Offset);
inputPtr += processingLength;
outputPtr_64 += 2;
outputPtr += Vector256<sbyte>.Count;
}
}
}
Expand All @@ -225,16 +234,16 @@ public static void TwoChannels(ReadOnlySpan<sbyte> input, Span<sbyte> output)
sbyte* finishPtr = inputP + input.Length;
while (inputPtr < finishPtr)
{
var loaded1 = Sse2.LoadVector128(inputPtr);
var loaded2 = Sse2.LoadVector128(inputPtr + Vector128<sbyte>.Count);
var loaded1 = Vector128.Load(inputPtr);
var loaded2 = Vector128.Load(inputPtr + Vector128<sbyte>.Count);
var shuffled1 = Ssse3.Shuffle(loaded1, shuffleMask);
var shuffled2 = Ssse3.Shuffle(loaded2, shuffleMask);
var channel1 = Sse2.UnpackLow(shuffled1.AsUInt64(), shuffled2.AsUInt64()).AsSByte();
var channel2 = Sse2.UnpackHigh(shuffled1.AsUInt64(), shuffled2.AsUInt64()).AsSByte();
Vector128.Store(channel1, outputPtr);
Vector128.Store(channel2, outputPtr + ch2Offset);
inputPtr += processingLength;
outputPtr += 16;
outputPtr += Vector128<sbyte>.Count;
}
}
}
Expand All @@ -261,7 +270,7 @@ public static void TwoChannels(ReadOnlySpan<sbyte> input, Span<sbyte> output)
AdvSimd.Store(outputPtr, loaded1.Value1);
AdvSimd.Store(outputPtr + ch2Offset, loaded1.Value2);
inputPtr += processingLength;
outputPtr += 16;
outputPtr += Vector128<sbyte>.Count;
}
}
}
Expand Down

0 comments on commit 1bb3084

Please sign in to comment.