
impl shr in Arm assembly #798


Draft · wants to merge 2 commits into master
62 changes: 60 additions & 2 deletions benches/uint.rs
@@ -383,13 +383,33 @@ fn bench_shl(c: &mut Criterion) {
)
});

group.bench_function("shl, U2048", |b| {
group.bench_function("shl (overflowing_shl), U2048", |b| {
b.iter_batched(
|| U2048::ONE,
|x| x.overflowing_shl(1024 + 10),
BatchSize::SmallInput,
)
});
group.bench_function("shl (<<, asm), U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x << (1024 + 10), BatchSize::SmallInput)
});

let mut rng = make_rng();
group.bench_function("shl (overflowing_shl), U2048, big number", |b| {
b.iter_batched(
|| U2048::random(&mut rng),
|x| x.overflowing_shl(1024 + 10),
BatchSize::SmallInput,
)
});
let mut rng = make_rng();
group.bench_function("shl (<<, asm), U2048, big number", |b| {
b.iter_batched(
|| U2048::random(&mut rng),
|x| x << (1024 + 10),
BatchSize::SmallInput,
)
});

group.finish();
}
@@ -421,14 +441,52 @@ fn bench_shr(c: &mut Criterion) {
)
});

group.bench_function("shr, U2048", |b| {
group.bench_function("shr (overflowing_shr), U2048", |b| {
b.iter_batched(
|| U2048::ONE,
|x| x.overflowing_shr(1024 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shr (>>, asm), U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x >> (1024 + 10), BatchSize::SmallInput)
});

let mut rng = make_rng();
group.bench_function("shr (overflowing_shr), U2048, big number", |b| {
b.iter_batched(
|| U2048::random(&mut rng),
|x| x.overflowing_shr(1024 + 10),
BatchSize::SmallInput,
)
});

let mut rng = make_rng();
group.bench_function("shr (>>, asm), U2048, big number", |b| {
b.iter_batched(
|| U2048::random(&mut rng),
|x| x >> (1024 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shr, U2048, zero", |b| {
b.iter_batched(
|| U2048::ONE,
|x| x.overflowing_shr(0),
BatchSize::SmallInput,
)
});

group.bench_function("shr, U2048, zero, asm", |b| {
b.iter_batched(
|| U2048::ONE,
|x| unsafe { x.shr_asm(0) },
BatchSize::SmallInput,
)
});

group.finish();
}

161 changes: 158 additions & 3 deletions src/uint/shl.rs
@@ -5,6 +5,152 @@ use core::ops::{Shl, ShlAssign};
use subtle::CtOption;

impl<const LIMBS: usize> Uint<LIMBS> {
/// Constant-time left shift, implemented in AArch64 assembly.
///
/// # Safety
///
/// The inline assembly only dereferences pointers derived from `self` and the
/// local result buffer, so callers just need to uphold the bound on `shift`
/// asserted below.
#[allow(unsafe_code)]
#[cfg(target_arch = "aarch64")]
pub unsafe fn shl_asm(&self, shift: u32) -> Uint<LIMBS> {
// The shift must be smaller than the total bit size of the integer.
assert!(
shift < Uint::<LIMBS>::BITS,
"`shift` must be within the bit size of the integer"
);
let mut res = Uint::ZERO;
let limbs = self.as_limbs();
let out_limbs = res.as_limbs_mut();

// Split the shift into whole-limb and in-limb parts.
let limb_shift = (shift >> 6) as u64; // how many 64-bit limbs to shift by
let bit_shift = shift & 0x3F; // remaining bit shift in [0,63]
let total_limbs = LIMBS as u64;
// A constant zero limb we can load from.
let zero: u64 = 0;

// Loop over each output limb index (0 <= i < total_limbs)
// For each i, we want to compute:
// j = i - limb_shift,
// a = (j >= 0) ? src[j] : 0,
// b = (j-1 >= 0) ? src[j-1] : 0,
// out[i] = (a << bit_shift) | (if bit_shift != 0 { b >> (64 - bit_shift) } else { 0 }).
//
// This implementation is constant-time: it always iterates over all the limbs
// and uses conditional select (csel) instead of data-dependent branches. A
// safe-Rust model of this recurrence is sketched after the function body below.
unsafe {
core::arch::asm!(
// x0: pointer to input limbs (limbs.as_ptr())
// x1: pointer to output limbs (out_limbs.as_mut_ptr())
// x2: loop index i (0 <= i < total_limbs)
// x3: bit_shift
// x4: limb_shift
// x5: total_limbs
// x12: pointer to a zero limb (&zero)
"mov x2, #0", // i = 0
"1:",
"cmp x2, x5", // if i >= total_limbs, exit loop
"b.ge 2f",

// Compute j = i - limb_shift into x7 (as signed)
"sub x7, x2, x4", // x7 = i - limb_shift
// Compute a_ptr = (j >= 0) ? (src + (j << 3)) : zero_ptr
"lsl x8, x7, #3", // x8 = j * 8
"add x8, x0, x8", // tentative a_ptr = src + (j * 8)
"cmp x7, #0", // test if j < 0
"csel x8, x12, x8, lt", // if j < 0, set a_ptr = zero_ptr

// Compute j2 = i - limb_shift - 1 into x7 again.
"sub x7, x2, x4", // x7 = i - limb_shift
"subs x7, x7, #1", // x7 = i - limb_shift - 1
// Compute b_ptr = (j2 >= 0) ? (src + (j2 << 3)) : zero_ptr
"lsl x9, x7, #3", // x9 = j2 * 8
"add x9, x0, x9", // tentative b_ptr = src + (j2 * 8)
"cmp x7, #0", // test if j2 < 0
"csel x9, x12, x9, lt", // if j2 < 0, set b_ptr = zero_ptr

// Load limbs a and b.
"ldr x10, [x8]", // x10 = a
"ldr x11, [x9]", // x11 = b

// Compute part_a = a << bit_shift.
"lsl x10, x10, x3", // x10 = a << bit_shift

// Compute part_b = b >> (64 - bit_shift).
"mov x6, #64", // prepare constant 64
"sub x6, x6, x3", // x6 = 64 - bit_shift (note: if x3==0, x6 becomes 64)
"lsr x11, x11, x6", // x11 = b >> (64 - bit_shift)
// If bit_shift is 0, force part_b to 0: the lsr above used a shift amount of 64,
// which AArch64 reduces mod 64 to 0, so x11 would otherwise still hold b.
"cmp x3, #0",
"csel x11, xzr, x11, eq", // if bit_shift == 0, set x11 = 0

// Combine parts: result = part_a OR part_b.
"orr x10, x10, x11",

// Store the result in out[i]. Compute the address: out_ptr + (i * 8).
"lsl x7, x2, #3", // offset = i * 8
"add x7, x1, x7", // destination pointer = out_ptr + offset
"str x10, [x7]", // store the computed limb

// Increment loop counter and repeat.
"add x2, x2, #1",
"b 1b",
"2:",
in("x0") limbs.as_ptr(), // input pointer
in("x1") out_limbs.as_mut_ptr(), // output pointer
in("x3") bit_shift, // bit shift amount
in("x4") limb_shift, // limb shift amount
in("x5") total_limbs, // total limbs
in("x12") &zero, // pointer to zero limb
lateout("x2") _, // loop index
lateout("x6") _, // temporary for (64 - bit_shift)
lateout("x7") _, // scratch (offset and index calculations)
lateout("x8") _, // pointer for a
lateout("x9") _, // pointer for b
lateout("x10") _, // holds part_a (and then combined result)
lateout("x11") _, // holds part_b
options(nostack)
)
};
res
}
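
For reference, here is the recurrence the loop implements, written in safe Rust. This is a reviewer's sketch for illustration, not part of the diff; it assumes the module's `Limb` type is in scope, and it uses plain branches for clarity, so unlike the assembly it is not constant-time:

    /// Safe-Rust model of the loop in `shl_asm` above (hypothetical;
    /// for illustration only).
    #[cfg(test)]
    #[allow(unused)]
    fn shl_model(&self, shift: u32) -> Uint<LIMBS> {
        assert!(shift < Uint::<LIMBS>::BITS);
        let limb_shift = (shift / 64) as usize; // whole-limb part
        let bit_shift = shift % 64; // in-limb part
        let src = self.as_limbs();
        let mut res = Uint::ZERO;
        let out = res.as_limbs_mut();
        for i in 0..LIMBS {
            // a = src[i - limb_shift] and b = src[i - limb_shift - 1],
            // with out-of-range indices reading as zero.
            let a = if i >= limb_shift { src[i - limb_shift].0 } else { 0 };
            let b = if i > limb_shift { src[i - limb_shift - 1].0 } else { 0 };
            // Drop the spill-over term when bit_shift == 0, as the csel does.
            let spill = if bit_shift == 0 { 0 } else { b >> (64 - bit_shift) };
            out[i] = Limb((a << bit_shift) | spill);
        }
        res
    }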

// TODO(dp): this only works for shift < 64 -- worth keeping?
/// Constant-time left shift, implemented in AArch64 assembly. Only valid for
/// shifts in `1..64`: a shift of 0 would make the carry extraction below
/// (`lsr` by `-shift`, i.e. by 64 mod 64 = 0) propagate whole limbs.
#[allow(unsafe_code, unused)]
#[cfg(target_arch = "aarch64")]
pub unsafe fn shl_asm_small_shift(&self, shift: u32) -> Uint<LIMBS> {
assert!((1..64).contains(&shift), "`shift` must be in 1..64");
let mut res = Uint::ZERO;
let limbs = self.as_limbs();
let out_limbs = res.as_limbs_mut();

unsafe {
core::arch::asm!(
"mov x6, #0", // Init carry

// Forward loop over the limbs (starting from low to high)
"1:",
"ldr x7, [x0], #8", // x7 ← Memory[x0], post-increment x0
"mov x8, x7", // x8 ← x7 (preserve original limb value)
"lsl x7, x7, x3", // Left shift x7 by x3 steps
"orr x7, x7, x6", // Combine with carry
"str x7, [x1], #8", // Store shifted limb and post-increment x1
"neg x9, x3", // x9 ← -x3 (for shift amount adjustment)
"lsr x6, x8, x9", // Right shift original limb to extract carry
"subs x2, x2, #1", // Decrement counter
"b.ne 1b", // Loop if not zero

// Register Operand Bindings
inout("x0") limbs.as_ptr() => _, // Input pointer to source limbs
inout("x1") out_limbs.as_mut_ptr() => _, // Output pointer for result limbs
inout("x2") LIMBS => _, // Limb counter
in("x3") shift, // Shift amount
out("x6") _, // Carry register

// Register Preservation
clobber_abi("C"),
options(nostack),
);
}
res
}
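
The same carry loop in safe Rust, again a hypothetical reviewer's sketch rather than part of the diff (valid only for shifts in 1..64, matching the assembly's restriction, and assuming `Limb` is in scope):

    /// Safe-Rust model of `shl_asm_small_shift` above (hypothetical;
    /// for illustration only).
    #[cfg(test)]
    #[allow(unused)]
    fn shl_small_model(&self, shift: u32) -> Uint<LIMBS> {
        assert!((1..64).contains(&shift));
        let mut res = Uint::ZERO;
        let out = res.as_limbs_mut();
        let mut carry = 0u64;
        for (i, limb) in self.as_limbs().iter().enumerate() {
            // Shift the limb and OR in the bits carried out of the limb below.
            out[i] = Limb((limb.0 << shift) | carry);
            carry = limb.0 >> (64 - shift);
        }
        res
    }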

/// Computes `self << shift`.
///
/// Panics if `shift >= Self::BITS`.
@@ -187,7 +333,10 @@ macro_rules! impl_shl {

#[inline]
fn shl(self, shift: $shift) -> Uint<LIMBS> {
<&Self>::shl(&self, shift)
#[allow(unsafe_code)]
unsafe {
self.shl_asm(u32::try_from(shift).expect("invalid shift"))
}
}
}

@@ -196,13 +345,19 @@ macro_rules! impl_shl {

#[inline]
fn shl(self, shift: $shift) -> Uint<LIMBS> {
Uint::<LIMBS>::shl(self, u32::try_from(shift).expect("invalid shift"))
#[allow(unsafe_code)]
unsafe {
self.shl_asm(u32::try_from(shift).expect("invalid shift"))
}
}
}

impl<const LIMBS: usize> ShlAssign<$shift> for Uint<LIMBS> {
fn shl_assign(&mut self, shift: $shift) {
*self = self.shl(shift)
#[allow(unsafe_code)]
unsafe {
*self = self.shl_asm(u32::try_from(shift).expect("invalid shift"))
}
}
}
)+
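
A sanity check one might add for the new operator path, sketched as a hypothetical test (not part of this diff). It compares the assembly shift against the crate's existing pure-Rust `Uint::shl`, which inherent-method resolution still picks over the rewritten `Shl` impls:

    #[cfg(all(test, target_arch = "aarch64"))]
    mod shl_asm_sanity {
        use crate::U2048;

        #[test]
        fn shl_asm_matches_pure_rust() {
            let x = U2048::MAX;
            for shift in [1u32, 63, 64, 65, 1024 + 10, 2047] {
                let expected = x.shl(shift); // inherent, pure-Rust shift
                #[allow(unsafe_code)]
                let actual = unsafe { x.shl_asm(shift) };
                assert_eq!(actual, expected, "mismatch at shift {shift}");
            }
        }
    }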