Skip to content

Commit

Permalink
[crypto] Check for multiples of 7, 11, and 31 in RSA keygen.
Browse files Browse the repository at this point in the history
All of these small primes have the nice property that 2^32 mod p = 4.

Signed-off-by: Jade Philipoom <[email protected]>
  • Loading branch information
jadephilipoom committed Jan 29, 2024
1 parent 3cfe5c7 commit 50891e0
Show file tree
Hide file tree
Showing 11 changed files with 578 additions and 138 deletions.
249 changes: 213 additions & 36 deletions sw/otbn/crypto/rsa_keygen.s
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,8 @@ generate_prime_candidate:
*
* Returns r such that r mod m = x mod m and r < 2^35.
*
* Can be used to speed up modular reduction on certain numbers.
* Can be used to speed up modular reduction on certain numbers, such as 3, 5,
* 17, and 65537.
*
* Because we know 2^32 mod m is 1, it follows that in general 2^(32*k) for any
* k are all 1 modulo m. This includes 2^256, so when we receive the input as
Expand Down Expand Up @@ -1252,15 +1253,15 @@ fold_bignum:
Loop invariants for iteration i (i=0..n-1):
x2 = dptr_x + i*32
x22 = 22
(w23 + FG0.C) \equiv x[0] + x[1] + ... + x[i-1] (mod F4)
(w23 + FG0.C) \equiv x[0] + x[1] + ... + x[i-1] (mod m)
*/
loop x30, 2
/* Load the next limb.
w22 <= x[i] */
bn.lid x22, 0(x2++)

/* Accumulate the new limb, incorporating the carry bit from the previous
round if there was one (this works because 2^256 \equiv 1 mod F4).
round if there was one (this works because 2^256 \equiv 1 mod m).
w23 <= (w23 + x[i] + FG0.C) mod 2^256
FG0.C <= (w23 + x[i] + FG0.C) / 2^256 */
bn.addc w23, w23, w22
Expand Down Expand Up @@ -1295,6 +1296,129 @@ fold_bignum:

ret

/**
 * Partially reduce a value modulo m such that 2^32 mod m == 4.
 *
 * Returns r such that r mod m = x mod m and r < 2^33.
 *
 * Can be used to speed up modular reduction on certain numbers, such as 7, 11,
 * and 31.
 *
 * The logic here is very similar to `fold_bignum`, except we need to multiply
 * by a power of 4 each time we fold. The core reasoning is that, for any
 * positive k:
 *   x0 + 2^(32*k)x1 \equiv x0 + (4**k)*x1 (mod m)
 *
 * This routine assumes that the number of limbs `n` is at most 8 (i.e. enough
 * for RSA-4096); bounds analysis may not work out for larger numbers.
 *
 * Flags: Flags have no meaning beyond the scope of this subroutine.
 *
 * @param[in]  x16: dptr_x, pointer to first limb of x in dmem
 * @param[in]  x30: n, number of 256-bit limbs for x, 1 <= n <= 8
 * @param[in]  w24: constant, 2^256 - 1
 * @param[in]  w31: all-zero
 * @param[out] w23: r, result
 *
 * clobbered registers: x2, x3, x22, w22, w23, w25
 * clobbered flag groups: FG0
 */
fold_bignum_pow2_32_equiv_4:
  /* Initialize constants for loop.
       x3 <= 32 (size of one limb in bytes)
       x22 <= 22 (index of w22, used as the indirect load destination) */
  li        x3, 32
  li        x22, 22

  /* Get a pointer to the end of the input.
       x2 <= dptr_x + n*32 */
  slli      x2, x30, 5
  add       x2, x2, x16

  /* Initialize the two-limb sum to zero and clear FG0.C.
       w25, w23 <= 0
       FG0.C <= 0 */
  bn.sub    w23, w23, w23
  bn.sub    w25, w25, w25

  /* Iterate through the limbs of x from most to least significant and
     accumulate them into the two-limb sum S = w23 + (w25 << 256).

     We shift S by 16 before adding each limb, since
     2^256 mod m = 4**8 = 2^16. The first limb leaves S < 2^256; every later
     iteration grows S by at most 17 bits (16 from the shift and 1 from the
     addition). Since we are assuming at most 8 limbs, the final sum fits in
     256+7*17 = 375 bits, i.e. w25 < 2^119.

     Loop invariants at the start of iteration i (i=0..n-1):
       x2 = dptr_x + (n-i)*32
       x3 = 32
       x22 = 22
       S \equiv x[n-i] + 2^256 * x[n-i+1] + ... + 2^(256*(i-1)) * x[n-1] (mod m)
       S < 2^(256+17*(i-1)) for i >= 1, and S = 0 for i = 0
  */
  loop      x30, 5
    /* Move the pointer down one limb.
         x2 <= dptr_x + (n-1-i)*32 */
    sub       x2, x2, x3

    /* Load the next limb.
         w22 <= x[n-1-i] */
    bn.lid    x22, 0(x2)

    /* Get the high part of the shifted sum.
         w25 <= ([w25,w23] << 16) >> 256 */
    bn.rshi   w25, w25, w23 >> 240

    /* Accumulate the new limb, propagating the carry into the high limb.
         [w25,w23] <= ([w25,w23] << 16) + x[n-1-i] */
    bn.add    w23, w22, w23 << 16
    bn.addc   w25, w31, w25

  /* Add the two limbs of the sum for a 257-bit result (2^256 \equiv 2^16).
       w23 + (FG0.C << 256) <= w23 + (w25 << 16) */
  bn.add    w23, w23, w25 << 16

  /* Capture the carry bit as a number.
       w25 <= FG0.C */
  bn.addc   w25, w31, w31

  /* Fold the carry back in. It has weight 2^256 \equiv 2^16 (mod m), so it
     must be added at bit 16 (not at bit 0 of the upper half, which would give
     it weight 2^128 \equiv 2^8). This cannot overflow: if the carry was set,
     the wrapped sum is smaller than (old w25 << 16) < 2^135.
       w23 <= w23 + (w25 << 16) */
  bn.add    w23, w23, w25 << 16

  /* Get the high 128 bits of the sum.
       w25 <= w23 >> 128 */
  bn.rshi   w25, w31, w23 >> 128

  /* Isolate the lower 128 bits of the sum.
       w22 <= w23[127:0] */
  bn.and    w22, w23, w24 >> 128

  /* Add the two halves of the sum (2^128 \equiv 4**4 = 2^8) to get a value
     of at most 137 bits. Plain add: no carry-in is wanted here.
       w23 <= w22 + (w25 << 8) */
  bn.add    w23, w22, w25 << 8

  /* Isolate the lower 64 bits of the sum.
       w22 <= w23[63:0] */
  bn.and    w22, w23, w24 >> 192

  /* Add the two halves of the sum (2^64 \equiv 4**2 = 2^4) to get a value of
     at most (137-64)+4+1=78 bits.
       w23 <= (w22 + ((w23 >> 64) << 4)) */
  bn.rshi   w23, w31, w23 >> 64
  bn.rshi   w23, w23, w31 >> 252
  bn.add    w23, w22, w23

  /* Isolate the lower 32 bits of the sum.
       w22 <= w23[31:0] */
  bn.and    w22, w23, w24 >> 224

  /* Add the two halves of the sum (2^32 \equiv 4 = 2^2) to get a value of at
     most (78-32)+2+1=49 bits.
       w23 <= (w22 + ((w23 >> 32) << 2)) */
  bn.rshi   w23, w31, w23 >> 32
  bn.rshi   w23, w23, w31 >> 254
  bn.add    w23, w22, w23

  /* Isolate the lower 32 bits of the sum again.
       w22 <= w23[31:0] */
  bn.and    w22, w23, w24 >> 224

  /* Add the two halves of the sum to get a 33-bit value.
       w23 <= (w22 + ((w23 >> 32) << 2)) */
  bn.rshi   w23, w31, w23 >> 32
  bn.rshi   w23, w23, w31 >> 254
  bn.add    w23, w22, w23

  ret

/**
* Check if a large number is relatively prime to 65537 (aka F4).
*
Expand Down Expand Up @@ -1375,13 +1499,21 @@ relprime_f4:
*
* Returns 0 if x is divisible by a small prime, 2^256 - 1 otherwise.
*
* In this implementation we specifically check the primes 3, 8, and 17, since
* these values m have the property that (1 << 8) mod m is 1, so we can use
* `fold_bignum` to check them very quickly. We use `fold_bignum` to get a
* 35-bit result and then fold the number a few more times to get a 9-bit
* result.
* In this implementation, we check the primes 3, 5, 7, 11, 17, and 31.
* These primes have special properties that allow us to compute the residue
* quickly:
* - p = {3,5,17} have the property that (2^8) mod p = 1
* - p = {7,11,31} have the property that (2^32) mod p = 4
*
* Testing
* Testing for these primes will catch approximately:
* 1 - ((1 - 1/3) * (1 - 1/5) * ... * (1 - 1/31))
* = 62.1% of composite numbers.
*
* Quick intuition for the estimate above: the multiplications calculate the
* proportion of composites we will *fail* to catch. At each multiplication
* step, we multiply the proportion of composites we still haven't caught by
* the proportion that the next small prime will *not* catch (e.g. 4/5 of
* numbers will not be multiples of 5).
*
* This routine is constant-time relative to x if x is not divisible by any
* small primes, but exits early if it finds that x is divisible by a small
Expand All @@ -1394,15 +1526,16 @@ relprime_f4:
* @param[in] w31: all-zero
* @param[out] w22: result, 0 if x is divisible by a small prime
*
* clobbered registers: x2, w22, w23, w24, w25
* clobbered registers: x2, x3, x10, x22, w22, w23, w24, w25, w26
* clobbered flag groups: FG0
*/
relprime_small_primes:
/* Generate a 256-bit mask.
w24 <= 2^256 - 1 */
bn.not w24, w31

/* Fold the bignum to get a 35-bit number r such that r mod F4 = x mod F4.
/* Fold the bignum to get a 35-bit number r such that r mod m = x mod m for
all m such that 2^32 mod m == 1.
w23 <= r */
jal x1, fold_bignum

Expand Down Expand Up @@ -1430,28 +1563,64 @@ relprime_small_primes:
w23 <= w22 + (w23 >> 8) */
bn.add w23, w22, w23 >> 8

/* Load the bit-length for `is_zero_mod_small_prime`. */
li x10, 9

/* Check the residue modulo 3.
x2 <= if (w23 mod 3) == 0 then 8 else 0 */
bn.mov w25, w23
bn.addi w24, w31, 3
bn.addi w26, w31, 3
jal x1, is_zero_mod_small_prime

/* If x2 != 0, exit early. */
bne x2, x0, __relprime_small_primes_fail

/* Check the residue modulo 5.
x2 <= if (w23 mod 5) == 0 then 8 else 0 */
bn.mov w25, w23
bn.addi w24, w31, 5
bn.addi w26, w31, 5
jal x1, is_zero_mod_small_prime

/* If x2 != 0, exit early. */
bne x2, x0, __relprime_small_primes_fail

/* Check the residue modulo 17.
x2 <= if (w23 mod 17) == 0 then 8 else 0 */
bn.mov w25, w23
bn.addi w24, w31, 17
bn.addi w26, w31, 17
jal x1, is_zero_mod_small_prime

/* If x2 != 0, exit early. */
bne x2, x0, __relprime_small_primes_fail

/* We didn't find any divisors among the primes p such that 2^8 mod p == 1;
now let's try primes such that 2^32 mod p == 4. This includes 7, 11, and
31. */

/* Fold the bignum to get a 33-bit number r such that r mod m = x mod m for
all m such that 2^32 mod m == 4.
w23 <= r */
jal x1, fold_bignum_pow2_32_equiv_4

/* Load the bit-length for `is_zero_mod_small_prime`. */
li x10, 33

/* Check the residue modulo 7.
x2 <= if (w23 mod 7) == 0 then 8 else 0 */
bn.addi w26, w31, 7
jal x1, is_zero_mod_small_prime

/* If x2 != 0, exit early. */
bne x2, x0, __relprime_small_primes_fail

/* Check the residue modulo 11.
x2 <= if (w23 mod 11) == 0 then 8 else 0 */
bn.addi w26, w31, 11
jal x1, is_zero_mod_small_prime

/* If x2 != 0, exit early. */
bne x2, x0, __relprime_small_primes_fail

/* Check the residue modulo 31.
x2 <= if (w23 mod 31) == 0 then 8 else 0 */
bn.addi w26, w31, 31
jal x1, is_zero_mod_small_prime

/* If x2 != 0, exit early. */
Expand All @@ -1467,46 +1636,54 @@ __relprime_small_primes_fail:
ret

/**
* Reduce a 9-bit number modulo a small number with conditional subtractions.
* Reduce input modulo a small number with conditional subtractions.
*
* Returns r = 8 if a mod m = 0, otherwise r = 0.
*
* Returns r = 8 if x mod m = 0, otherwise r = 0.
* Helper function for `relprime_small_primes`. This routine takes time linear
* in the number of bits of the input, so it's slow for large numbers and
* should only be used as a last step once the bit-bound is low.
*
* Helper function for `relprime_small_primes`.
* The sum of the bit-length of the input and the modulus should not exceed
* 256.
*
* This function runs in constant time.
*
* @param[in] w23: x, input, x < 2^9.
* @param[in] w24: m, modulus, 2 < m < 2^256.
* @param[in] x10: len, max. number of bits in input, 1 < len
* @param[in] w23: a, input, a < 2^len.
* @param[in] w26: m, modulus, 2 < m < 2^(256-len).
* @param[in] w31: all-zero
* @param[out] x2: result, 8 if x mod m = 0 and otherwise 0
* @param[out] x2: result, 8 if a mod m = 0 and otherwise 0
*
* clobbered registers: x2, w24, w25
* clobbered registers: x2, w22, w25, w26
* clobbered flag groups: FG0
*/
is_zero_mod_small_prime:
/* Copy input. */
bn.mov w25, w23

/* Since we know m is at least 2 bits, start with m << 7 as a value that will
definitely be greater than x / 2.
w24 <= m << 7 */
bn.rshi w24, w24, w31 >> 249
/* Initialize shifted modulus for loop.
w26 <= m << (len - 1) */
li x2, 1
sub x2, x10, x2
loop x2, 1
bn.add w26, w26, w26

/* Repeatedly reduce using conditional subtractions.
Loop invariant (i=7 to 0):
w24 = m << i
Loop invariant (i=len-1 to 0):
w26 = m << i
w25 < 2*(m << i)
w25 mod m = x mod m
w25 mod m = a mod m
*/
loopi 8, 3
/* w22 <= w25 - w24 */
bn.sub w22, w25, w24
loop x10, 3
/* w22 <= w25 - w26 */
bn.sub w22, w25, w26
/* Select the subtraction only if it did not underflow.
w25 <= FG0.C ? w25 : w22 */
bn.sel w25, w25, w22, FG0.C
/* w24 <= w24 >> 1 */
bn.rshi w24, w31, w24 >> 1
/* w26 <= w26 >> 1 */
bn.rshi w26, w31, w26 >> 1

/* Check if w25 is 0.
FG0.Z <= w25 == 0 */
Expand Down
Loading

0 comments on commit 50891e0

Please sign in to comment.