Skip to content

Commit bd63402

Browse files
Optimize Ascii.Equals when widening (#87141)
* Optimize Ascii.Equals when widening
* add BoundedMemory tests to ensure that boundaries are respected
---------
Co-authored-by: Adam Sitnik <[email protected]>
1 parent b78345e commit bd63402

File tree

2 files changed

+112
-26
lines changed

2 files changed

+112
-26
lines changed

src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Equality.cs

Lines changed: 55 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
4848
|| (typeof(TLeft) == typeof(byte) && typeof(TRight) == typeof(ushort))
4949
|| (typeof(TLeft) == typeof(ushort) && typeof(TRight) == typeof(ushort)));
5050

51-
if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<TRight>.Count)
51+
if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<TLeft>.Count)
5252
{
5353
for (nuint i = 0; i < length; ++i)
5454
{
@@ -61,42 +61,34 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
6161
}
6262
}
6363
}
64-
else if (Avx.IsSupported && length >= (uint)Vector256<TRight>.Count)
64+
else if (Avx.IsSupported && length >= (uint)Vector256<TLeft>.Count)
6565
{
6666
ref TLeft currentLeftSearchSpace = ref left;
67-
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref currentLeftSearchSpace, length - TLoader.Count256);
6867
ref TRight currentRightSearchSpace = ref right;
69-
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector256<TRight>.Count);
70-
71-
Vector256<TRight> leftValues;
72-
Vector256<TRight> rightValues;
68+
// Add Vector256<TLeft>.Count because TLeft == TRight
69+
// Or we are in the Widen case where we iterate 2 * TRight.Count which is the same as TLeft.Count
70+
Debug.Assert(Vector256<TLeft>.Count == Vector256<TRight>.Count
71+
|| (typeof(TLoader) == typeof(WideningLoader) && Vector256<TLeft>.Count == Vector256<TRight>.Count * 2));
72+
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector256<TLeft>.Count);
7373

7474
// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
7575
do
7676
{
77-
leftValues = TLoader.Load256(ref currentLeftSearchSpace);
78-
rightValues = Vector256.LoadUnsafe(ref currentRightSearchSpace);
79-
80-
if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues))
77+
if (!TLoader.EqualAndAscii(ref currentLeftSearchSpace, ref currentRightSearchSpace))
8178
{
8279
return false;
8380
}
8481

85-
currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector256<TRight>.Count);
86-
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, TLoader.Count256);
82+
currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector256<TLeft>.Count);
83+
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, Vector256<TLeft>.Count);
8784
}
8885
while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd));
8986

9087
// If any elements remain, process the last vector in the search space.
91-
if (length % (uint)Vector256<TRight>.Count != 0)
88+
if (length % (uint)Vector256<TLeft>.Count != 0)
9289
{
93-
leftValues = TLoader.Load256(ref oneVectorAwayFromLeftEnd);
94-
rightValues = Vector256.LoadUnsafe(ref oneVectorAwayFromRightEnd);
95-
96-
if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues))
97-
{
98-
return false;
99-
}
90+
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector256<TLeft>.Count);
91+
return TLoader.EqualAndAscii(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
10092
}
10193
}
10294
else
@@ -363,6 +355,7 @@ private interface ILoader<TLeft, TRight>
363355
static abstract nuint Count256 { get; }
364356
static abstract Vector128<TRight> Load128(ref TLeft ptr);
365357
static abstract Vector256<TRight> Load256(ref TLeft ptr);
358+
static abstract bool EqualAndAscii(ref TLeft left, ref TRight right);
366359
}
367360

368361
private readonly struct PlainLoader<T> : ILoader<T, T> where T : unmanaged, INumberBase<T>
@@ -371,6 +364,21 @@ private interface ILoader<TLeft, TRight>
371364
public static nuint Count256 => (uint)Vector256<T>.Count;
372365
public static Vector128<T> Load128(ref T ptr) => Vector128.LoadUnsafe(ref ptr);
373366
public static Vector256<T> Load256(ref T ptr) => Vector256.LoadUnsafe(ref ptr);
367+
368+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
369+
[CompExactlyDependsOn(typeof(Avx))]
370+
public static bool EqualAndAscii(ref T left, ref T right)
371+
{
372+
Vector256<T> leftValues = Vector256.LoadUnsafe(ref left);
373+
Vector256<T> rightValues = Vector256.LoadUnsafe(ref right);
374+
375+
if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues))
376+
{
377+
return false;
378+
}
379+
380+
return true;
381+
}
374382
}
375383

376384
private readonly struct WideningLoader : ILoader<byte, ushort>
@@ -403,6 +411,32 @@ public static Vector256<ushort> Load256(ref byte ptr)
403411
(Vector128<ushort> lower, Vector128<ushort> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref ptr));
404412
return Vector256.Create(lower, upper);
405413
}
414+
415+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
416+
[CompExactlyDependsOn(typeof(Avx))]
417+
public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)
418+
{
419+
// We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
420+
Debug.Assert(Vector256<byte>.Count == Vector256<ushort>.Count * 2);
421+
422+
Vector256<byte> leftNotWidened = Vector256.LoadUnsafe(ref utf8);
423+
if (!AllCharsInVectorAreAscii(leftNotWidened))
424+
{
425+
return false;
426+
}
427+
428+
(Vector256<ushort> leftLower, Vector256<ushort> leftUpper) = Vector256.Widen(leftNotWidened);
429+
Vector256<ushort> right = Vector256.LoadUnsafe(ref utf16);
430+
Vector256<ushort> rightNext = Vector256.LoadUnsafe(ref utf16, (uint)Vector256<ushort>.Count);
431+
432+
// A branchless version of "leftLower != right || leftUpper != rightNext"
433+
if (((leftLower ^ right) | (leftUpper ^ rightNext)) != Vector256<ushort>.Zero)
434+
{
435+
return false;
436+
}
437+
438+
return true;
439+
}
406440
}
407441
}
408442
}

src/libraries/System.Text.Encoding/tests/Ascii/EqualsTests.cs

Lines changed: 57 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,24 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Buffers;
45
using System.Collections.Generic;
56
using System.Linq;
67
using System.Runtime.Intrinsics;
78
using Xunit;
89

910
namespace System.Text.Tests
1011
{
11-
public abstract class AsciiEqualityTests
12+
public abstract class AsciiEqualityTests<TLeft, TRight>
13+
where TLeft : unmanaged
14+
where TRight : unmanaged
1215
{
1316
protected abstract bool Equals(string left, string right);
1417
protected abstract bool EqualsIgnoreCase(string left, string right);
1518
protected abstract bool Equals(byte[] left, byte[] right);
1619
protected abstract bool EqualsIgnoreCase(byte[] left, byte[] right);
20+
protected abstract bool Equals(ReadOnlySpan<TLeft> left, ReadOnlySpan<TRight> right);
21+
protected abstract bool EqualsIgnoreCase(ReadOnlySpan<TLeft> left, ReadOnlySpan<TRight> right);
1722

1823
public static IEnumerable<object[]> ValidAsciiInputs
1924
{
@@ -140,9 +145,32 @@ public void Equals_EqualValues_ButNonAscii_ReturnsFalse(byte[] input)
140145
[MemberData(nameof(ContainingNonAsciiCharactersBuffers))]
141146
public void EqualsIgnoreCase_EqualValues_ButNonAscii_ReturnsFalse(byte[] input)
142147
=> Assert.False(EqualsIgnoreCase(input, input));
148+
149+
[Theory]
150+
[InlineData(PoisonPagePlacement.After, PoisonPagePlacement.After)]
151+
[InlineData(PoisonPagePlacement.After, PoisonPagePlacement.Before)]
152+
[InlineData(PoisonPagePlacement.Before, PoisonPagePlacement.After)]
153+
[InlineData(PoisonPagePlacement.Before, PoisonPagePlacement.Before)]
154+
public void Boundaries_Are_Respected(PoisonPagePlacement leftPoison, PoisonPagePlacement rightPoison)
155+
{
156+
for (int size = 1; size < 129; size++)
157+
{
158+
using BoundedMemory<TLeft> left = BoundedMemory.Allocate<TLeft>(size, leftPoison);
159+
using BoundedMemory<TRight> right = BoundedMemory.Allocate<TRight>(size, rightPoison);
160+
161+
left.Span.Fill(default);
162+
right.Span.Fill(default);
163+
164+
left.MakeReadonly();
165+
right.MakeReadonly();
166+
167+
Assert.True(Equals(left.Span, right.Span));
168+
Assert.True(EqualsIgnoreCase(left.Span, right.Span));
169+
}
170+
}
143171
}
144172

145-
public class AsciiEqualityTests_Byte_Byte : AsciiEqualityTests
173+
public class AsciiEqualityTests_Byte_Byte : AsciiEqualityTests<byte, byte>
146174
{
147175
protected override bool Equals(string left, string right)
148176
=> Ascii.Equals(Encoding.ASCII.GetBytes(left), Encoding.ASCII.GetBytes(right));
@@ -155,9 +183,15 @@ protected override bool Equals(byte[] left, byte[] right)
155183

156184
protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
157185
=> Ascii.EqualsIgnoreCase(left, right);
186+
187+
protected override bool Equals(ReadOnlySpan<byte> left, ReadOnlySpan<byte> right)
188+
=> Ascii.Equals(left, right);
189+
190+
protected override bool EqualsIgnoreCase(ReadOnlySpan<byte> left, ReadOnlySpan<byte> right)
191+
=> Ascii.EqualsIgnoreCase(left, right);
158192
}
159193

160-
public class AsciiEqualityTests_Byte_Char : AsciiEqualityTests
194+
public class AsciiEqualityTests_Byte_Char : AsciiEqualityTests<byte, char>
161195
{
162196
protected override bool Equals(string left, string right)
163197
=> Ascii.Equals(Encoding.ASCII.GetBytes(left), right);
@@ -170,9 +204,15 @@ protected override bool Equals(byte[] left, byte[] right)
170204

171205
protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
172206
=> Ascii.EqualsIgnoreCase(left, right.Select(b => (char)b).ToArray());
207+
208+
protected override bool Equals(ReadOnlySpan<byte> left, ReadOnlySpan<char> right)
209+
=> Ascii.Equals(left, right);
210+
211+
protected override bool EqualsIgnoreCase(ReadOnlySpan<byte> left, ReadOnlySpan<char> right)
212+
=> Ascii.EqualsIgnoreCase(left, right);
173213
}
174214

175-
public class AsciiEqualityTests_Char_Byte : AsciiEqualityTests
215+
public class AsciiEqualityTests_Char_Byte : AsciiEqualityTests<char, byte>
176216
{
177217
protected override bool Equals(string left, string right)
178218
=> Ascii.Equals(left, Encoding.ASCII.GetBytes(right));
@@ -185,9 +225,15 @@ protected override bool Equals(byte[] left, byte[] right)
185225

186226
protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
187227
=> Ascii.EqualsIgnoreCase(left.Select(b => (char)b).ToArray(), right);
228+
229+
protected override bool Equals(ReadOnlySpan<char> left, ReadOnlySpan<byte> right)
230+
=> Ascii.Equals(left, right);
231+
232+
protected override bool EqualsIgnoreCase(ReadOnlySpan<char> left, ReadOnlySpan<byte> right)
233+
=> Ascii.EqualsIgnoreCase(left, right);
188234
}
189235

190-
public class AsciiEqualityTests_Char_Char : AsciiEqualityTests
236+
public class AsciiEqualityTests_Char_Char : AsciiEqualityTests<char, char>
191237
{
192238
protected override bool Equals(string left, string right)
193239
=> Ascii.Equals(left, right);
@@ -200,5 +246,11 @@ protected override bool Equals(byte[] left, byte[] right)
200246

201247
protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
202248
=> Ascii.EqualsIgnoreCase(left.Select(b => (char)b).ToArray(), right.Select(b => (char)b).ToArray());
249+
250+
protected override bool Equals(ReadOnlySpan<char> left, ReadOnlySpan<char> right)
251+
=> Ascii.Equals(left, right);
252+
253+
protected override bool EqualsIgnoreCase(ReadOnlySpan<char> left, ReadOnlySpan<char> right)
254+
=> Ascii.EqualsIgnoreCase(left, right);
203255
}
204256
}

0 commit comments

Comments
 (0)