Skip to content

Commit 2959612

Browse files
authored
Fix BigInteger.Rotate{Left,Right} for backport (#112878)
* Add BigInteger.Rotate* tests * Fix BigInteger.Rotate* * avoid stackalloc * Add comment
1 parent b54529f commit 2959612

File tree

5 files changed

+799
-30
lines changed

5 files changed

+799
-30
lines changed

src/libraries/System.Runtime.Numerics/src/System/Numerics/BigInteger.cs

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1701,7 +1701,7 @@ private static BigInteger Add(ReadOnlySpan<uint> leftBits, int leftSign, ReadOnl
17011701
}
17021702

17031703
if (bitsFromPool != null)
1704-
ArrayPool<uint>.Shared.Return(bitsFromPool);
1704+
ArrayPool<uint>.Shared.Return(bitsFromPool);
17051705

17061706
return result;
17071707
}
@@ -2636,7 +2636,7 @@ public static implicit operator BigInteger(nuint value)
26362636

26372637
if (zdFromPool != null)
26382638
ArrayPool<uint>.Shared.Return(zdFromPool);
2639-
exit:
2639+
exit:
26402640
if (xdFromPool != null)
26412641
ArrayPool<uint>.Shared.Return(xdFromPool);
26422642

@@ -3239,7 +3239,27 @@ public static BigInteger PopCount(BigInteger value)
32393239
public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
32403240
{
32413241
value.AssertValid();
3242-
int byteCount = (value._bits is null) ? sizeof(int) : (value._bits.Length * 4);
3242+
3243+
bool negx = value._sign < 0;
3244+
uint smallBits = NumericsHelpers.Abs(value._sign);
3245+
scoped ReadOnlySpan<uint> bits = value._bits;
3246+
if (bits.IsEmpty)
3247+
{
3248+
bits = new ReadOnlySpan<uint>(in smallBits);
3249+
}
3250+
3251+
int xl = bits.Length;
3252+
if (negx && (bits[^1] >= kuMaskHighBit) && ((bits[^1] != kuMaskHighBit) || bits.IndexOfAnyExcept(0u) != (bits.Length - 1)))
3253+
{
3254+
// We check for a special case where its sign bit could be outside the uint array after 2's complement conversion.
3255+
// For example given [0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF], its 2's complement is [0x01, 0x00, 0x00]
3256+
// After a 32 bit right shift, it becomes [0x00, 0x00] which is [0x00, 0x00] when converted back.
3257+
// The expected result is [0x00, 0x00, 0xFFFFFFFF] (2's complement) or [0x00, 0x00, 0x01] when converted back
3258+
// If the 2's component's last element is a 0, we will track the sign externally
3259+
++xl;
3260+
}
3261+
3262+
int byteCount = xl * 4;
32433263

32443264
// Normalize the rotate amount to drop full rotations
32453265
rotateAmount = (int)(rotateAmount % (byteCount * 8L));
@@ -3256,14 +3276,13 @@ public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
32563276
(int digitShift, int smallShift) = Math.DivRem(rotateAmount, kcbitUint);
32573277

32583278
uint[]? xdFromPool = null;
3259-
int xl = value._bits?.Length ?? 1;
3260-
32613279
Span<uint> xd = (xl <= BigIntegerCalculator.StackAllocThreshold)
32623280
? stackalloc uint[BigIntegerCalculator.StackAllocThreshold]
32633281
: xdFromPool = ArrayPool<uint>.Shared.Rent(xl);
32643282
xd = xd.Slice(0, xl);
3283+
xd[^1] = 0;
32653284

3266-
bool negx = value.GetPartsForBitManipulation(xd);
3285+
bits.CopyTo(xd);
32673286

32683287
int zl = xl;
32693288
uint[]? zdFromPool = null;
@@ -3374,7 +3393,28 @@ public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
33743393
public static BigInteger RotateRight(BigInteger value, int rotateAmount)
33753394
{
33763395
value.AssertValid();
3377-
int byteCount = (value._bits is null) ? sizeof(int) : (value._bits.Length * 4);
3396+
3397+
3398+
bool negx = value._sign < 0;
3399+
uint smallBits = NumericsHelpers.Abs(value._sign);
3400+
scoped ReadOnlySpan<uint> bits = value._bits;
3401+
if (bits.IsEmpty)
3402+
{
3403+
bits = new ReadOnlySpan<uint>(in smallBits);
3404+
}
3405+
3406+
int xl = bits.Length;
3407+
if (negx && (bits[^1] >= kuMaskHighBit) && ((bits[^1] != kuMaskHighBit) || bits.IndexOfAnyExcept(0u) != (bits.Length - 1)))
3408+
{
3409+
// We check for a special case where its sign bit could be outside the uint array after 2's complement conversion.
3410+
// For example given [0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF], its 2's complement is [0x01, 0x00, 0x00]
3411+
// After a 32 bit right shift, it becomes [0x00, 0x00] which is [0x00, 0x00] when converted back.
3412+
// The expected result is [0x00, 0x00, 0xFFFFFFFF] (2's complement) or [0x00, 0x00, 0x01] when converted back
3413+
// If the 2's component's last element is a 0, we will track the sign externally
3414+
++xl;
3415+
}
3416+
3417+
int byteCount = xl * 4;
33783418

33793419
// Normalize the rotate amount to drop full rotations
33803420
rotateAmount = (int)(rotateAmount % (byteCount * 8L));
@@ -3391,14 +3431,13 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
33913431
(int digitShift, int smallShift) = Math.DivRem(rotateAmount, kcbitUint);
33923432

33933433
uint[]? xdFromPool = null;
3394-
int xl = value._bits?.Length ?? 1;
3395-
33963434
Span<uint> xd = (xl <= BigIntegerCalculator.StackAllocThreshold)
33973435
? stackalloc uint[BigIntegerCalculator.StackAllocThreshold]
33983436
: xdFromPool = ArrayPool<uint>.Shared.Rent(xl);
33993437
xd = xd.Slice(0, xl);
3438+
xd[^1] = 0;
34003439

3401-
bool negx = value.GetPartsForBitManipulation(xd);
3440+
bits.CopyTo(xd);
34023441

34033442
int zl = xl;
34043443
uint[]? zdFromPool = null;
@@ -3445,19 +3484,12 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
34453484
{
34463485
int carryShift = kcbitUint - smallShift;
34473486

3448-
int dstIndex = 0;
3449-
int srcIndex = digitShift;
3487+
int dstIndex = xd.Length - 1;
3488+
int srcIndex = digitShift == 0
3489+
? xd.Length - 1
3490+
: digitShift - 1;
34503491

3451-
uint carry = 0;
3452-
3453-
if (digitShift == 0)
3454-
{
3455-
carry = xd[^1] << carryShift;
3456-
}
3457-
else
3458-
{
3459-
carry = xd[srcIndex - 1] << carryShift;
3460-
}
3492+
uint carry = xd[digitShift] << carryShift;
34613493

34623494
do
34633495
{
@@ -3466,22 +3498,22 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
34663498
zd[dstIndex] = (part >> smallShift) | carry;
34673499
carry = part << carryShift;
34683500

3469-
dstIndex++;
3470-
srcIndex++;
3501+
dstIndex--;
3502+
srcIndex--;
34713503
}
3472-
while (srcIndex < xd.Length);
3504+
while ((uint)srcIndex < (uint)xd.Length); // is equivalent to (srcIndex >= 0 && srcIndex < xd.Length)
34733505

3474-
srcIndex = 0;
3506+
srcIndex = xd.Length - 1;
34753507

3476-
while (dstIndex < zd.Length)
3508+
while ((uint)dstIndex < (uint)zd.Length) // is equivalent to (dstIndex >= 0 && dstIndex < zd.Length)
34773509
{
34783510
uint part = xd[srcIndex];
34793511

34803512
zd[dstIndex] = (part >> smallShift) | carry;
34813513
carry = part << carryShift;
34823514

3483-
dstIndex++;
3484-
srcIndex++;
3515+
dstIndex--;
3516+
srcIndex--;
34853517
}
34863518
}
34873519

src/libraries/System.Runtime.Numerics/tests/BigInteger/MyBigInt.cs

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ public static BigInteger DoBinaryOperatorMine(BigInteger num1, BigInteger num2,
110110
return new BigInteger(ShiftLeft(bytes1, Negate(bytes2)).ToArray());
111111
case "b<<":
112112
return new BigInteger(ShiftLeft(bytes1, bytes2).ToArray());
113+
case "bRotateLeft":
114+
return new BigInteger(RotateLeft(bytes1, bytes2).ToArray());
115+
case "bRotateRight":
116+
return new BigInteger(RotateLeft(bytes1, Negate(bytes2)).ToArray());
113117
case "b^":
114118
return new BigInteger(Xor(bytes1, bytes2).ToArray());
115119
case "b|":
@@ -774,6 +778,105 @@ public static List<byte> ShiftRight(List<byte> bytes)
774778
return bresult;
775779
}
776780

781+
public static List<byte> RotateRight(List<byte> bytes)
782+
{
783+
List<byte> bresult = new List<byte>();
784+
785+
byte bottom = (byte)(bytes[0] & 0x01);
786+
787+
for (int i = 0; i < bytes.Count; i++)
788+
{
789+
byte newbyte = bytes[i];
790+
791+
newbyte = (byte)(newbyte / 2);
792+
if ((i != (bytes.Count - 1)) && ((bytes[i + 1] & 0x01) == 1))
793+
{
794+
newbyte += 128;
795+
}
796+
if ((i == (bytes.Count - 1)) && (bottom != 0))
797+
{
798+
newbyte += 128;
799+
}
800+
bresult.Add(newbyte);
801+
}
802+
803+
return bresult;
804+
}
805+
806+
public static List<byte> RotateLeft(List<byte> bytes)
807+
{
808+
List<byte> bresult = new List<byte>();
809+
810+
bool prevHead = (bytes[bytes.Count - 1] & 0x80) != 0;
811+
812+
for (int i = 0; i < bytes.Count; i++)
813+
{
814+
byte newbyte = bytes[i];
815+
816+
newbyte = (byte)(newbyte * 2);
817+
if (prevHead)
818+
{
819+
newbyte += 1;
820+
}
821+
822+
bresult.Add(newbyte);
823+
824+
prevHead = (bytes[i] & 0x80) != 0;
825+
}
826+
827+
return bresult;
828+
}
829+
830+
831+
public static List<byte> RotateLeft(List<byte> bytes1, List<byte> bytes2)
832+
{
833+
List<byte> bytes1Copy = Copy(bytes1);
834+
int byteShift = (int)new BigInteger(Divide(Copy(bytes2), new List<byte>(new byte[] { 8 })).ToArray());
835+
sbyte bitShift = (sbyte)new BigInteger(Remainder(bytes2, new List<byte>(new byte[] { 8 })).ToArray());
836+
837+
Trim(bytes1);
838+
839+
byte fill = (bytes1[bytes1.Count - 1] & 0x80) != 0 ? byte.MaxValue : (byte)0;
840+
841+
if (fill == 0 && bytes1.Count > 1 && bytes1[bytes1.Count - 1] == 0)
842+
bytes1.RemoveAt(bytes1.Count - 1);
843+
844+
while (bytes1.Count % 4 != 0)
845+
{
846+
bytes1.Add(fill);
847+
}
848+
849+
byteShift %= bytes1.Count;
850+
if (byteShift == 0 && bitShift == 0)
851+
return bytes1Copy;
852+
853+
for (int i = 0; i < Math.Abs(bitShift); i++)
854+
{
855+
if (bitShift < 0)
856+
{
857+
bytes1 = RotateRight(bytes1);
858+
}
859+
else
860+
{
861+
bytes1 = RotateLeft(bytes1);
862+
}
863+
}
864+
865+
List<byte> temp = new List<byte>();
866+
for (int i = 0; i < bytes1.Count; i++)
867+
{
868+
temp.Add(bytes1[(i - byteShift + bytes1.Count) % bytes1.Count]);
869+
}
870+
bytes1 = temp;
871+
872+
if (fill == 0)
873+
bytes1.Add(0);
874+
875+
Trim(bytes1);
876+
877+
return bytes1;
878+
}
879+
777880
public static List<byte> SetLength(List<byte> bytes, int size)
778881
{
779882
List<byte> bresult = new List<byte>();

0 commit comments

Comments
 (0)