Skip to content

Commit 8520fdb

Browse files
Faster constant-time division (#643)
This replaces the `div_rem` method for Uint and BoxedUint with one based on the updated vartime division method. Unlike the old method, this one panics if `NonZero(0)` is given as the divisor. The `sqrt` method relied on the old behavior and is updated.

The `Default` implementation of `NonZero` was also producing invalid `NonZero(0)` values, which affected the `NonZero::map` usage in `checked_div`. The trait implementation is updated to return `Self::ONE` as the default.

Signed-off-by: Andrew Whitehead <[email protected]>
1 parent e362a47 commit 8520fdb

File tree

9 files changed

+227
-137
lines changed

9 files changed

+227
-137
lines changed

benches/uint.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,14 @@ fn bench_sqrt(c: &mut Criterion) {
320320
BatchSize::SmallInput,
321321
)
322322
});
323+
324+
group.bench_function("sqrt_vartime, U256", |b| {
325+
b.iter_batched(
326+
|| U256::random(&mut OsRng),
327+
|x| x.sqrt_vartime(),
328+
BatchSize::SmallInput,
329+
)
330+
});
323331
}
324332

325333
criterion_group!(

src/limb/shl.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,6 @@ impl Limb {
1111
pub const fn shl(self, shift: u32) -> Self {
1212
Limb(self.0 << shift)
1313
}
14-
15-
/// Computes `self << 1` and return the result and the carry (0 or 1).
16-
#[inline(always)]
17-
pub(crate) const fn shl1(self) -> (Self, Self) {
18-
(Self(self.0 << 1), Self(self.0 >> Self::HI_BIT))
19-
}
2014
}
2115

2216
macro_rules! impl_shl {

src/non_zero.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use serdect::serde::{
2121
};
2222

2323
/// Wrapper type for non-zero integers.
24-
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, PartialOrd, Ord)]
24+
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
2525
#[repr(transparent)]
2626
pub struct NonZero<T>(pub(crate) T);
2727

@@ -210,6 +210,15 @@ where
210210
}
211211
}
212212

213+
impl<T> Default for NonZero<T>
214+
where
215+
T: Constants,
216+
{
217+
fn default() -> Self {
218+
Self(T::ONE)
219+
}
220+
}
221+
213222
impl<T> Deref for NonZero<T> {
214223
type Target = T;
215224

src/uint/boxed/div.rs

Lines changed: 103 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
use crate::{
44
uint::{boxed, div_limb::div3by2},
55
BoxedUint, CheckedDiv, ConstChoice, ConstantTimeSelect, DivRemLimb, Limb, NonZero, Reciprocal,
6-
RemLimb, Wrapping,
6+
RemLimb, Word, Wrapping,
77
};
88
use core::ops::{Div, DivAssign, Rem, RemAssign};
9-
use subtle::{Choice, ConstantTimeLess, CtOption};
9+
use subtle::CtOption;
1010

1111
impl BoxedUint {
1212
/// Computes `self / rhs` using a pre-made reciprocal,
@@ -118,41 +118,113 @@ impl BoxedUint {
118118
/// Perform checked division, returning a [`CtOption`] which `is_some`
119119
/// only if the rhs != 0
120120
pub fn checked_div(&self, rhs: &Self) -> CtOption<Self> {
121-
let q = self.div_rem_unchecked(rhs).0;
122-
CtOption::new(q, !rhs.is_zero())
121+
let is_nz = rhs.is_nonzero();
122+
let nz = NonZero(Self::ct_select(
123+
&Self::one_with_precision(self.bits_precision()),
124+
rhs,
125+
is_nz,
126+
));
127+
let q = self.div_rem_unchecked(&nz).0;
128+
CtOption::new(q, is_nz)
123129
}
124130

125-
/// Computes `self` / `rhs`, returns the quotient (q), remainder (r) without checking if `rhs`
126-
/// is zero.
127-
///
128-
/// This function is constant-time with respect to both `self` and `rhs`.
129131
fn div_rem_unchecked(&self, rhs: &Self) -> (Self, Self) {
130-
debug_assert_eq!(self.bits_precision(), rhs.bits_precision());
131-
let mb = rhs.bits();
132-
let bits_precision = self.bits_precision();
133-
let mut rem = self.clone();
134-
let mut quo = Self::zero_with_precision(bits_precision);
135-
let (mut c, _overflow) = rhs.overflowing_shl(bits_precision - mb);
136-
let mut i = bits_precision;
137-
let mut done = Choice::from(0u8);
138-
139-
loop {
140-
let (mut r, borrow) = rem.sbb(&c, Limb::ZERO);
141-
rem.ct_assign(&r, !(Choice::from((borrow.0 & 1) as u8) | done));
142-
r = quo.bitor(&Self::one());
143-
quo.ct_assign(&r, !(Choice::from((borrow.0 & 1) as u8) | done));
144-
if i == 0 {
145-
break;
132+
// Based on Section 4.3.1, of The Art of Computer Programming, Volume 2, by Donald E. Knuth.
133+
// Further explanation at https://janmr.com/blog/2014/04/basic-multiple-precision-long-division/
134+
135+
let size = self.limbs.len();
136+
assert_eq!(
137+
size,
138+
rhs.limbs.len(),
139+
"the precision of the divisor must match the dividend"
140+
);
141+
142+
// Short circuit for single-word precision
143+
if size == 1 {
144+
let (quo, rem_limb) = self.div_rem_limb(rhs.limbs[0].to_nz().expect("zero divisor"));
145+
let mut rem = Self::zero_with_precision(self.bits_precision());
146+
rem.limbs[0] = rem_limb;
147+
return (quo, rem);
148+
}
149+
150+
let dbits = rhs.bits();
151+
assert!(dbits > 0, "zero divisor");
152+
let dwords = (dbits + Limb::BITS - 1) / Limb::BITS;
153+
let lshift = (Limb::BITS - (dbits % Limb::BITS)) % Limb::BITS;
154+
155+
// Shift entire divisor such that the high bit is set
156+
let mut y = rhs.shl((size as u32) * Limb::BITS - dbits).to_limbs();
157+
// Shift the dividend to align the words
158+
let (x, mut x_hi) = self.shl_limb(lshift);
159+
let mut x = x.to_limbs();
160+
let mut xi = size - 1;
161+
let mut x_lo = x[size - 1];
162+
let mut i;
163+
let mut carry;
164+
165+
let reciprocal = Reciprocal::new(y[size - 1].to_nz().expect("zero divisor"));
166+
167+
while xi > 0 {
168+
// Divide high dividend words by the high divisor word to estimate the quotient word
169+
let (mut quo, _) = div3by2(x_hi.0, x_lo.0, x[xi - 1].0, &reciprocal, y[size - 2].0);
170+
171+
// This loop is a no-op once xi is smaller than the number of words in the divisor
172+
let done = ConstChoice::from_u32_lt(xi as u32, dwords - 1);
173+
quo = done.select_word(quo, 0);
174+
175+
// Subtract q*divisor from the dividend
176+
carry = Limb::ZERO;
177+
let mut borrow = Limb::ZERO;
178+
let mut tmp;
179+
i = 0;
180+
while i <= xi {
181+
(tmp, carry) = Limb::ZERO.mac(y[size - xi + i - 1], Limb(quo), carry);
182+
(x[i], borrow) = x[i].sbb(tmp, borrow);
183+
i += 1;
146184
}
147-
i -= 1;
148-
// when `i < mb`, the computation is actually done, so we ensure `quo` and `rem`
149-
// aren't modified further (but do the remaining iterations anyway to be constant-time)
150-
done = i.ct_lt(&mb);
151-
c.shr1_assign();
152-
quo.ct_assign(&quo.shl1(), !done);
185+
(_, borrow) = x_hi.sbb(carry, borrow);
186+
187+
// If the subtraction borrowed, then decrement q and add back the divisor
188+
// The probability of this being needed is very low, about 2/(Limb::MAX+1)
189+
let ct_borrow = ConstChoice::from_word_mask(borrow.0);
190+
carry = Limb::ZERO;
191+
i = 0;
192+
while i <= xi {
193+
(x[i], carry) = x[i].adc(
194+
Limb::select(Limb::ZERO, y[size - xi + i - 1], ct_borrow),
195+
carry,
196+
);
197+
i += 1;
198+
}
199+
quo = ct_borrow.select_word(quo, quo.saturating_sub(1));
200+
201+
// Store the quotient within dividend and set x_hi to the current highest word
202+
x_hi = Limb::select(x[xi], x_hi, done);
203+
x[xi] = Limb::select(Limb(quo), x[xi], done);
204+
x_lo = Limb::select(x[xi - 1], x_lo, done);
205+
xi -= 1;
206+
}
207+
208+
let limb_div = ConstChoice::from_u32_eq(1, dwords);
209+
// Calculate quotient and remainder for the case where the divisor is a single word
210+
let (quo2, rem2) = div3by2(x_hi.0, x_lo.0, 0, &reciprocal, 0);
211+
212+
// Adjust the quotient for single limb division
213+
x[0] = Limb::select(x[0], Limb(quo2), limb_div);
214+
215+
// Copy out the remainder
216+
y[0] = Limb::select(x[0], Limb(rem2 as Word), limb_div);
217+
i = 1;
218+
while i < size {
219+
y[i] = Limb::select(Limb::ZERO, x[i], ConstChoice::from_u32_lt(i as u32, dwords));
220+
y[i] = Limb::select(y[i], x_hi, ConstChoice::from_u32_eq(i as u32, dwords - 1));
221+
i += 1;
153222
}
154223

155-
(quo, rem)
224+
(
225+
Self { limbs: x }.shr((dwords - 1) * Limb::BITS),
226+
Self { limbs: y }.shr(lshift),
227+
)
156228
}
157229
}
158230

src/uint/boxed/shl.rs

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -126,25 +126,6 @@ impl BoxedUint {
126126
success.map(|_| result)
127127
}
128128

129-
/// Computes `self << 1` in constant-time.
130-
pub(crate) fn shl1(&self) -> Self {
131-
let mut ret = self.clone();
132-
ret.shl1_assign();
133-
ret
134-
}
135-
136-
/// Computes `self << 1` in-place in constant-time.
137-
pub(crate) fn shl1_assign(&mut self) {
138-
let mut carry = self.limbs[0].0 >> Limb::HI_BIT;
139-
self.limbs[0].shl_assign(1);
140-
for i in 1..self.limbs.len() {
141-
let new_carry = self.limbs[i].0 >> Limb::HI_BIT;
142-
self.limbs[i].shl_assign(1);
143-
self.limbs[i].0 |= carry;
144-
carry = new_carry
145-
}
146-
}
147-
148129
/// Computes `self << shift` where `0 <= shift < Limb::BITS`,
149130
/// returning the result and the carry.
150131
pub(crate) fn shl_limb(&self, shift: u32) -> (Self, Limb) {
@@ -230,14 +211,6 @@ impl ShlVartime for BoxedUint {
230211
mod tests {
231212
use super::BoxedUint;
232213

233-
#[test]
234-
fn shl1_assign() {
235-
let mut n = BoxedUint::from(0x3c442b21f19185fe433f0a65af902b8fu128);
236-
let n_shl1 = BoxedUint::from(0x78885643e3230bfc867e14cb5f20571eu128);
237-
n.shl1_assign();
238-
assert_eq!(n, n_shl1);
239-
}
240-
241214
#[test]
242215
fn shl() {
243216
let one = BoxedUint::one_with_precision(128);

src/uint/boxed/sqrt.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! [`BoxedUint`] square root operations.
22
3-
use subtle::{ConstantTimeEq, ConstantTimeGreater, CtOption};
3+
use subtle::{ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, CtOption};
44

55
use crate::{BitOps, BoxedUint, ConstantTimeSelect, NonZero, SquareRoot};
66

@@ -23,24 +23,22 @@ impl BoxedUint {
2323
// Repeat enough times to guarantee result has stabilized.
2424
let mut i = 0;
2525
let mut x_prev = x.clone(); // keep the previous iteration in case we need to roll back.
26+
let mut nz_x = NonZero(x.clone());
2627

2728
// TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
2829
while i < self.log2_bits() + 2 {
2930
x_prev.limbs.clone_from_slice(&x.limbs);
3031

3132
// Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
32-
33-
let (nz_x, is_nonzero) = (NonZero(x.clone()), x.is_nonzero());
33+
let x_nonzero = x.is_nonzero();
34+
let mut j = 0;
35+
while j < nz_x.0.limbs.len() {
36+
nz_x.0.limbs[j].conditional_assign(&x.limbs[j], x_nonzero);
37+
j += 1;
38+
}
3439
let (q, _) = self.div_rem(&nz_x);
35-
36-
// A protection in case `self == 0`, which will make `x == 0`
37-
let q = Self::ct_select(
38-
&Self::zero_with_precision(self.bits_precision()),
39-
&q,
40-
is_nonzero,
41-
);
42-
43-
x = x.wrapping_add(&q).shr1();
40+
x.conditional_adc_assign(&q, x_nonzero);
41+
x.shr1_assign();
4442
i += 1;
4543
}
4644

0 commit comments

Comments
 (0)