diff --git a/benches/uint.rs b/benches/uint.rs
index 7bbc9e17..4c60bd4e 100644
--- a/benches/uint.rs
+++ b/benches/uint.rs
@@ -383,13 +383,33 @@ fn bench_shl(c: &mut Criterion) {
         )
     });

-    group.bench_function("shl, U2048", |b| {
+    group.bench_function("shl (overflowing_shl), U2048", |b| {
         b.iter_batched(
             || U2048::ONE,
             |x| x.overflowing_shl(1024 + 10),
             BatchSize::SmallInput,
         )
     });
+    group.bench_function("shl (<<, asm), U2048", |b| {
+        b.iter_batched(|| U2048::ONE, |x| x << (1024 + 10), BatchSize::SmallInput)
+    });
+
+    let mut rng = make_rng();
+    group.bench_function("shl (overflowing_shl), U2048, big number", |b| {
+        b.iter_batched(
+            || U2048::random(&mut rng),
+            |x| x.overflowing_shl(1024 + 10),
+            BatchSize::SmallInput,
+        )
+    });
+    let mut rng = make_rng();
+    group.bench_function("shl (<<, asm), U2048, big number", |b| {
+        b.iter_batched(
+            || U2048::random(&mut rng),
+            |x| x << (1024 + 10),
+            BatchSize::SmallInput,
+        )
+    });

     group.finish();
 }
@@ -421,7 +441,7 @@ fn bench_shr(c: &mut Criterion) {
         )
     });

-    group.bench_function("shr, U2048", |b| {
+    group.bench_function("shr (overflowing_shr), U2048", |b| {
         b.iter_batched(
             || U2048::ONE,
             |x| x.overflowing_shr(1024 + 10),
@@ -429,6 +449,44 @@
         )
     });

+    group.bench_function("shr (>>, asm), U2048", |b| {
+        b.iter_batched(|| U2048::ONE, |x| x >> (1024 + 10), BatchSize::SmallInput)
+    });
+
+    let mut rng = make_rng();
+    group.bench_function("shr (overflowing_shr), U2048, big number", |b| {
+        b.iter_batched(
+            || U2048::random(&mut rng),
+            |x| x.overflowing_shr(1024 + 10),
+            BatchSize::SmallInput,
+        )
+    });
+
+    let mut rng = make_rng();
+    group.bench_function("shr (>>, asm), U2048, big number", |b| {
+        b.iter_batched(
+            || U2048::random(&mut rng),
+            |x| x >> (1024 + 10),
+            BatchSize::SmallInput,
+        )
+    });
+
+    group.bench_function("shr, U2048, zero", |b| {
+        b.iter_batched(
+            || U2048::ONE,
+            |x| x.overflowing_shr(0),
+            BatchSize::SmallInput,
+        )
+    });
+
+    group.bench_function("shr, U2048, zero, asm", |b| {
+        b.iter_batched(
+            || U2048::ONE,
+            |x| unsafe { x.shr_asm(0) },
+            BatchSize::SmallInput,
+        )
+    });
+
     group.finish();
 }

diff --git a/src/uint/shl.rs b/src/uint/shl.rs
index 6716988c..e9523cdb 100644
--- a/src/uint/shl.rs
+++ b/src/uint/shl.rs
@@ -5,6 +5,152 @@ use core::ops::{Shl, ShlAssign};
 use subtle::CtOption;

 impl<const LIMBS: usize> Uint<LIMBS> {
+    /// Constant time left shift, implemented in ARM assembly.
+    #[allow(unsafe_code)]
+    #[cfg(target_arch = "aarch64")]
+    pub unsafe fn shl_asm(self: &Uint<LIMBS>, shift: u32) -> Uint<LIMBS> {
+        // `shift` must be smaller than the total number of bits in the multi-precision integer.
+        assert!(
+            shift < Uint::<LIMBS>::BITS,
+            "`shift` within the bit size of the integer"
+        );
+        let mut res = Uint::ZERO;
+        let limbs = self.as_limbs();
+        let out_limbs = res.as_limbs_mut();
+
+        // Split the shift into whole-limb and in-limb parts.
+        let limb_shift = (shift >> 6) as u64; // how many 64-bit limbs to shift by
+        let bit_shift = shift & 0x3F; // remaining bit shift in [0, 63]
+        let total_limbs = LIMBS as u64;
+        // A constant zero limb we can load from.
+        let zero: u64 = 0;
+
+        // Loop over each output limb index (0 <= i < total_limbs).
+        // For each i, we want to compute:
+        //     j = i - limb_shift,
+        //     a = (j >= 0) ? src[j] : 0,
+        //     b = (j - 1 >= 0) ? src[j - 1] : 0,
+        //     out[i] = (a << bit_shift) | (if bit_shift != 0 { b >> (64 - bit_shift) } else { 0 }).
+        //
+        // This implementation is constant time by always iterating over all limbs and using
+        // conditional select (csel) to avoid branches.
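+        //
+        // For reference, a pure-Rust sketch of the same per-limb computation (illustrative
+        // only; it uses branches for clarity where the asm below uses csel, and it assumes
+        // 64-bit limbs):
+        //
+        //     for i in 0..LIMBS {
+        //         let j = i as i64 - limb_shift as i64;
+        //         let a = if j >= 0 { limbs[j as usize].0 } else { 0 };
+        //         let b = if j >= 1 { limbs[(j - 1) as usize].0 } else { 0 };
+        //         let carry = if bit_shift == 0 { 0 } else { b >> (64 - bit_shift) };
+        //         out_limbs[i] = Limb((a << bit_shift) | carry);
+        //     }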
+        unsafe {
+            core::arch::asm!(
+                // x0: pointer to input limbs (limbs.as_ptr())
+                // x1: pointer to output limbs (out_limbs.as_mut_ptr())
+                // x2: loop index i (0 <= i < total_limbs)
+                // x3: bit_shift
+                // x4: limb_shift
+                // x5: total_limbs
+                // x12: pointer to a zero limb (&zero)
+                "mov x2, #0", // i = 0
+                "1:",
+                "cmp x2, x5", // if i >= total_limbs, exit loop
+                "b.ge 2f",
+
+                // Compute j = i - limb_shift into x7 (as signed).
+                "sub x7, x2, x4", // x7 = i - limb_shift
+                // Compute a_ptr = (j >= 0) ? (src + (j << 3)) : zero_ptr.
+                "lsl x8, x7, #3", // x8 = j * 8
+                "add x8, x0, x8", // tentative a_ptr = src + (j * 8)
+                "cmp x7, #0", // test if j < 0
+                "csel x8, x12, x8, lt", // if j < 0, set a_ptr = zero_ptr
+
+                // Compute j2 = i - limb_shift - 1 into x7 again.
+                "sub x7, x2, x4", // x7 = i - limb_shift
+                "subs x7, x7, #1", // x7 = i - limb_shift - 1
+                // Compute b_ptr = (j2 >= 0) ? (src + (j2 << 3)) : zero_ptr.
+                "lsl x9, x7, #3", // x9 = j2 * 8
+                "add x9, x0, x9", // tentative b_ptr = src + (j2 * 8)
+                "cmp x7, #0", // test if j2 < 0
+                "csel x9, x12, x9, lt", // if j2 < 0, set b_ptr = zero_ptr
+
+                // Load limbs a and b.
+                "ldr x10, [x8]", // x10 = a
+                "ldr x11, [x9]", // x11 = b
+
+                // Compute part_a = a << bit_shift.
+                "lsl x10, x10, x3", // x10 = a << bit_shift
+
+                // Compute part_b = b >> (64 - bit_shift).
+                "mov x6, #64", // prepare constant 64
+                "sub x6, x6, x3", // x6 = 64 - bit_shift (note: if x3 == 0, x6 becomes 64)
+                "lsr x11, x11, x6", // x11 = b >> (64 - bit_shift)
+                // If bit_shift is 0, force part_b to 0 (since b >> 64 would be undefined).
+                "cmp x3, #0",
+                "csel x11, xzr, x11, eq", // if bit_shift == 0, set x11 = 0
+
+                // Combine parts: result = part_a OR part_b.
+                "orr x10, x10, x11",
+
+                // Store the result in out[i]. Compute the address: out_ptr + (i * 8).
+                "lsl x7, x2, #3", // offset = i * 8
+                "add x7, x1, x7", // destination pointer = out_ptr + offset
+                "str x10, [x7]", // store the computed limb
+
+                // Increment loop counter and repeat.
+                "add x2, x2, #1",
+                "b 1b",
+                "2:",
+                in("x0") limbs.as_ptr(), // input pointer
+                in("x1") out_limbs.as_mut_ptr(), // output pointer
+                in("x3") bit_shift, // bit shift amount
+                in("x4") limb_shift, // limb shift amount
+                in("x5") total_limbs, // total limbs
+                in("x12") &zero, // pointer to zero limb
+                lateout("x2") _, // loop index
+                lateout("x6") _, // temporary for (64 - bit_shift)
+                lateout("x7") _, // scratch (offset and index calculations)
+                lateout("x8") _, // pointer for a
+                lateout("x9") _, // pointer for b
+                lateout("x10") _, // holds part_a (and then the combined result)
+                lateout("x11") _, // holds part_b
+                options(nostack)
+            )
+        };
+        res
+    }
+
+    // TODO(dp): This works for shift < 64 -- worth keeping?
+    /// Constant time left shift, implemented in ARM assembly; only works for small shifts (< 64).
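+    //
+    // The small-shift variant below propagates a carry limb by limb, from the lowest
+    // limb to the highest. A pure-Rust sketch of the idea (illustrative only; it assumes
+    // 64-bit limbs and 0 < shift < 64):
+    //
+    //     let mut carry = 0u64;
+    //     for i in 0..LIMBS {
+    //         let limb = limbs[i].0;
+    //         out_limbs[i] = Limb((limb << shift) | carry);
+    //         carry = limb >> (64 - shift);
+    //     }
+    //
+    // In the asm below, the carry is produced by `neg x9, x3` followed by `lsr x6, x8, x9`:
+    // AArch64 register-controlled shifts use only the low 6 bits of the shift amount,
+    // so shifting by `-shift` is the same as shifting by `64 - shift`.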
+    #[allow(unsafe_code, unused)]
+    pub unsafe fn shl_asm_small_shift(&self, shift: u32) -> Uint<LIMBS> {
+        assert!(shift < Uint::<LIMBS>::BITS, "Shift out of bounds");
+        let mut res = Uint::ZERO;
+        let limbs = self.as_limbs();
+        let out_limbs = res.as_limbs_mut();
+
+        unsafe {
+            core::arch::asm!(
+                "mov x6, #0", // Init carry
+
+                // Forward loop over the limbs (starting from low to high)
+                "1:",
+                "ldr x7, [x0], #8", // x7 ← Memory[x0], post-increment x0
+                "mov x8, x7", // x8 ← x7 (preserve original limb value)
+                "lsl x7, x7, x3", // Left shift x7 by x3 steps
+                "orr x7, x7, x6", // Combine with carry
+                "str x7, [x1], #8", // Store shifted limb and post-increment x1
+                "neg x9, x3", // x9 ← -x3 (for shift amount adjustment)
+                "lsr x6, x8, x9", // Right shift original limb to extract carry
+                "subs x2, x2, #1", // Decrement counter
+                "b.ne 1b", // Loop if not zero
+
+                // Register operand bindings
+                inout("x0") limbs.as_ptr() => _, // Input pointer to source limbs
+                inout("x1") out_limbs.as_mut_ptr() => _, // Output pointer for result limbs
+                inout("x2") LIMBS => _, // Limb counter
+                in("x3") shift, // Shift amount
+                out("x6") _, // Carry register
+
+                // Register preservation
+                clobber_abi("C"),
+                options(nostack),
+            );
+        }
+        res
+    }
+
     /// Computes `self << shift`.
     ///
     /// Panics if `shift >= Self::BITS`.
@@ -187,7 +333,10 @@ macro_rules! impl_shl {

             #[inline]
             fn shl(self, shift: $shift) -> Uint<LIMBS> {
-                <&Self>::shl(&self, shift)
+                #[allow(unsafe_code)]
+                unsafe {
+                    self.shl_asm(u32::try_from(shift).expect("invalid shift"))
+                }
             }
         }

@@ -196,13 +345,19 @@ macro_rules! impl_shl {

             #[inline]
             fn shl(self, shift: $shift) -> Uint<LIMBS> {
-                Uint::<LIMBS>::shl(self, u32::try_from(shift).expect("invalid shift"))
+                #[allow(unsafe_code)]
+                unsafe {
+                    self.shl_asm(u32::try_from(shift).expect("invalid shift"))
+                }
             }
         }

         impl<const LIMBS: usize> ShlAssign<$shift> for Uint<LIMBS> {
             fn shl_assign(&mut self, shift: $shift) {
-                *self = self.shl(shift)
+                #[allow(unsafe_code)]
+                unsafe {
+                    *self = self.shl_asm(u32::try_from(shift).expect("invalid shift"))
+                }
             }
         }
         )+
diff --git a/src/uint/shr.rs b/src/uint/shr.rs
index df6db1f7..d5a3f58e 100644
--- a/src/uint/shr.rs
+++ b/src/uint/shr.rs
@@ -5,6 +5,156 @@ use core::ops::{Shr, ShrAssign};
 use subtle::CtOption;

 impl<const LIMBS: usize> Uint<LIMBS> {
+    /// Constant time right shift, implemented in ARM assembly.
+    #[allow(unsafe_code)]
+    #[cfg(target_arch = "aarch64")]
+    pub unsafe fn shr_asm(&self, shift: u32) -> Uint<LIMBS> {
+        // Ensure shift is less than total bits.
+        assert!(
+            shift < Uint::<LIMBS>::BITS,
+            "`shift` must be less than the bit size of the integer"
+        );
+        let mut res = Uint::ZERO;
+        let limbs = self.as_limbs();
+        let out_limbs = res.as_limbs_mut();
+
+        // Split shift into whole-limb and in-limb parts.
+        let limb_shift = (shift >> 6) as u64; // number of 64-bit limbs to shift
+        let bit_shift = shift & 0x3F; // remaining bit shift (0..63)
+        let total_limbs = LIMBS as u64;
+        // A constant zero limb we can safely load when the index is out of range.
+        let zero: u64 = 0;
+
+        // We now loop over each output limb index i in 0..total_limbs.
+        // For each output index i, compute:
+        //     a = (if i + limb_shift < total_limbs then src[i + limb_shift] else 0)
+        //     b = (if i + limb_shift + 1 < total_limbs then src[i + limb_shift + 1] else 0)
+        //     result[i] = (if bit_shift != 0:
+        //                      (a >> bit_shift) | (b << (64 - bit_shift))
+        //                  else:
+        //                      a)
+        unsafe {
+            core::arch::asm!(
+                // x0: pointer to input limbs (limbs.as_ptr())
+                // x1: pointer to output limbs (out_limbs.as_mut_ptr())
+                // x2: loop index i (0..total_limbs)
+                // x3: bit_shift (0..63)
+                // x4: limb_shift
+                // x5: total_limbs
+                // x12: pointer to a zero limb (&zero)
+                "mov x2, #0", // i = 0
+                "1:",
+                "cmp x2, x5", // if i >= total_limbs, exit loop
+                "b.ge 2f",
+
+                // Compute a_index = i + limb_shift, store in x7.
+                "add x7, x2, x4", // x7 = i + limb_shift
+                "lsl x8, x7, #3", // x8 = a_index * 8
+                "add x8, x0, x8", // tentative pointer for a: src + (a_index * 8)
+                "cmp x7, x5", // is (i + limb_shift) < total_limbs?
+                "csel x8, x8, x12, lt", // if x7 < x5, keep x8; else use the zero pointer
+
+                // Compute b_index = i + limb_shift + 1.
+                "add x7, x2, x4", // x7 = i + limb_shift again
+                "add x7, x7, #1", // x7 = i + limb_shift + 1
+                "lsl x9, x7, #3", // x9 = b_index * 8
+                "add x9, x0, x9", // tentative pointer for b: src + (b_index * 8)
+                "cmp x7, x5", // is (i + limb_shift + 1) < total_limbs?
+                "csel x9, x9, x12, lt", // if true, keep x9; else use the zero pointer
+
+                // Load the limbs for a and b.
+                "ldr x10, [x8]", // x10 = a
+                "ldr x11, [x9]", // x11 = b
+
+                // For bit shifting:
+                // Compute part_a = a >> bit_shift.
+                "lsr x10, x10, x3", // x10 = a >> bit_shift
+
+                // Compute part_b = b << (64 - bit_shift).
+                "mov x6, #64", // x6 = 64
+                "sub x6, x6, x3", // x6 = 64 - bit_shift
+                "lsl x11, x11, x6", // x11 = b << (64 - bit_shift)
+                // If bit_shift is zero, force part_b to 0.
+                "cmp x3, #0",
+                "csel x11, xzr, x11, eq", // if bit_shift == 0, x11 = 0
+
+                // Combine the two parts.
+                "orr x10, x10, x11", // result = part_a OR part_b
+
+                // Compute the output pointer for index i.
+                "lsl x7, x2, #3", // offset = i * 8
+                "add x7, x1, x7", // destination pointer = out + offset
+                "str x10, [x7]", // store result limb
+
+                // Increment loop index.
+                "add x2, x2, #1",
+                "b 1b",
+                "2:",
+                in("x0") limbs.as_ptr(), // input pointer
+                in("x1") out_limbs.as_mut_ptr(), // output pointer
+                in("x3") bit_shift, // bit shift value
+                in("x4") limb_shift, // limb shift value
+                in("x5") total_limbs, // total number of limbs
+                in("x12") &zero, // pointer to a zero limb
+                lateout("x2") _, // loop counter
+                lateout("x6") _, // scratch for (64 - bit_shift)
+                lateout("x7") _, // scratch for indices and offsets
+                lateout("x8") _, // pointer for a
+                lateout("x9") _, // pointer for b
+                lateout("x10") _, // holds shifted a / result
+                lateout("x11") _, // holds shifted b
+                options(nostack)
+            )
+        };
+        res
+    }
+
+    #[allow(unsafe_code)]
+    #[cfg(target_arch = "aarch64")]
+    // TODO(dp): This works for shift < 64 -- worth keeping?
+    /// Constant time right shift, implemented in ARM assembly; only works for small shifts (< 64).
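+    //
+    // The variant below propagates a carry limb by limb, from the highest limb down to
+    // the lowest. A pure-Rust sketch of the idea (illustrative only; it assumes 64-bit
+    // limbs and 0 < shift < 64):
+    //
+    //     let mut carry = 0u64;
+    //     for i in (0..LIMBS).rev() {
+    //         let limb = limbs[i].0;
+    //         out_limbs[i] = Limb((limb >> shift) | carry);
+    //         carry = limb << (64 - shift);
+    //     }
+    //
+    // As in the left-shift variant, `neg x9, x3` followed by `lsl x6, x8, x9` computes
+    // `limb << (64 - shift)`, because AArch64 register-controlled shifts use only the
+    // low 6 bits of the shift amount (so -shift mod 64 == 64 - shift).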
+    pub unsafe fn shr_asm_small_shift(self: &Uint<LIMBS>, shift: u32) -> Uint<LIMBS> {
+        assert!(
+            shift < Uint::<LIMBS>::BITS,
+            "`shift` within the bit size of the integer"
+        );
+        let mut res = Uint::ZERO;
+        let limbs = self.as_limbs();
+        let out_limbs = res.as_limbs_mut();
+        unsafe {
+            core::arch::asm!(
+                "mov x6, #0", // Init carry
+
+                // Reverse loop over the limbs (starting from high to low)
+                "add x0, x0, x2, LSL #3", // Move input pointer to end
+                "add x1, x1, x2, LSL #3", // Move output pointer to end
+
+                "1:",
+                "ldr x7, [x0, #-8]!", // x7 ← Memory[x0 - 8], pre-decrement x0
+                "mov x8, x7", // x8 ← x7 (preserve original limb value)
+                "lsr x7, x7, x3", // Right shift x7 by x3 steps
+                "orr x7, x7, x6", // Combine with carry
+                "str x7, [x1, #-8]!", // Store shifted limb and pre-decrement x1
+                "neg x9, x3", // x9 ← -x3 (for shift amount adjustment)
+                "lsl x6, x8, x9", // Left shift original limb to extract carry
+                "subs x2, x2, #1", // Decrement counter
+                "b.ne 1b", // Loop if not zero
+
+                // Register operand bindings
+                inout("x0") limbs.as_ptr() => _, // Input pointer to source limbs
+                inout("x1") out_limbs.as_mut_ptr() => _, // Output pointer for result limbs
+                inout("x2") LIMBS => _, // Limb counter (positive, decrementing)
+                in("x3") shift, // Shift amount
+                out("x6") _, // Carry register
+
+                // Register preservation
+                clobber_abi("C"),
+                options(nostack),
+            );
+        }
+        res
+    }
+
     /// Computes `self >> shift`.
     ///
     /// Panics if `shift >= Self::BITS`.
@@ -167,7 +317,8 @@ macro_rules! impl_shr {

             #[inline]
             fn shr(self, shift: $shift) -> Uint<LIMBS> {
-                <&Self>::shr(&self, shift)
+                #[allow(unsafe_code)]
+                unsafe { self.shr_asm(u32::try_from(shift).expect("invalid shift")) }
             }
         }

@@ -176,13 +327,15 @@ macro_rules! impl_shr {

             #[inline]
             fn shr(self, shift: $shift) -> Uint<LIMBS> {
-                Uint::<LIMBS>::shr(self, u32::try_from(shift).expect("invalid shift"))
+                #[allow(unsafe_code)]
+                unsafe { self.shr_asm(u32::try_from(shift).expect("invalid shift")) }
             }
         }

         impl<const LIMBS: usize> ShrAssign<$shift> for Uint<LIMBS> {
             fn shr_assign(&mut self, shift: $shift) {
-                *self = self.shr(shift)
+                #[allow(unsafe_code)]
+                unsafe { *self = self.shr_asm(u32::try_from(shift).expect("invalid shift")) }
             }
         }
         )+
@@ -208,7 +361,7 @@ impl<const LIMBS: usize> ShrVartime for Uint<LIMBS> {

 #[cfg(test)]
 mod tests {
-    use crate::{U128, U256, Uint};
+    use crate::{U64, U128, U256, Uint};

     const N: U256 =
         U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141");
@@ -216,10 +369,74 @@
     const N_2: U256 =
         U256::from_be_hex("7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF5D576E7357A4501DDFE92F46681B20A0");

+    const SIXTY_FIVE: U256 =
+        U256::from_be_hex("00000000000000007FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF5D576E7357A4501D");
+
+    const EIGHTY_EIGHT: U256 =
+        U256::from_be_hex("0000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF");
+
+    const SIXTY_FOUR: U256 =
+        U256::from_be_hex("0000000000000000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03B");
+
+    #[test]
+    fn shr_simple() {
+        let mut t = U256::from(2u8);
+        assert_eq!(t >> 1, U256::from(1u8));
+        t = U256::from(0x300u16);
+        assert_eq!(t >> 8, U256::from(3u8));
+    }
+
     #[test]
     fn shr1() {
-        assert_eq!(N.shr1(), N_2);
-        assert_eq!(N >> 1, N_2);
+        assert_eq!(N.shr1(), N_2, "1-bit right shift, specialized");
+        assert_eq!(N >> 1, N_2, "1-bit right shift, general");
+    }
+
+    #[test]
+    fn shr_one() {
+        let x = U64::from_be_hex("0000000000000002");
+        let expected = U64::from_be_hex("0000000000000001");
+        assert_eq!(
+            x >> 1,
+            expected,
+            "\nx: {x:0b}, \nexpected \n{expected:0b}, got \n{:0b}",
+            x >> 1,
+        );
+    }
+
+    #[test]
+    fn shr_2() {
+        let x =
+            U128::from_be_hex("ffffffffffffffffffffffffffffffff");
+        let expected = x.overflowing_shr(1).unwrap();
+        assert_eq!(
+            x >> 1,
+            expected,
+            "\nx: {x:0b}, \nexpected \n{expected:0b}, got \n{:0b}",
+            x >> 1,
+        );
+    }
+
+    #[test]
+    fn shr65() {
+        assert_eq!(N.overflowing_shr_vartime(65).unwrap(), SIXTY_FIVE);
+        assert_eq!(N >> 65, SIXTY_FIVE);
+    }
+
+    #[allow(unsafe_code)]
+    #[test]
+    fn shr_asm() {
+        let x =
+            U256::from_be_hex("FFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD03641410000000000000000");
+        let shift = 64;
+        let y = unsafe { x.shr_asm(shift) };
+        let expected =
+            U256::from_be_hex("0000000000000000FFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141");
+        assert_eq!(y, expected);
+    }
+
+    #[test]
+    fn shr88() {
+        assert_eq!(N.overflowing_shr_vartime(88).unwrap(), EIGHTY_EIGHT);
+        assert_eq!(N >> 88, EIGHTY_EIGHT);
     }

     #[test]
@@ -229,11 +446,17 @@
-    #[should_panic(expected = "`shift` within the bit size of the integer")]
+    #[should_panic(expected = "`shift` must be less than the bit size of the integer")]
     fn shr256() {
         let _ = N >> 256;
     }

+    #[test]
+    fn shr64() {
+        assert_eq!(N.overflowing_shr_vartime(64).unwrap(), SIXTY_FOUR);
+        assert_eq!(N >> 64, SIXTY_FOUR);
+    }
+
     #[test]
     fn shr_wide_1_1_128() {
         assert_eq!(