Skip to content

Commit

Permalink
[crypto] Check for multiples of 3, 5, and 17 in RSA keygen.
Browse files Browse the repository at this point in the history
These small primes have the convenient property that 2^8 mod p = 1,
which significantly speeds up the check and allows it to share code with
relprime_f4.

Signed-off-by: Jade Philipoom <[email protected]>
  • Loading branch information
jadephilipoom committed Feb 21, 2025
1 parent 0b02f55 commit 1ba1b59
Show file tree
Hide file tree
Showing 10 changed files with 573 additions and 55 deletions.
276 changes: 221 additions & 55 deletions sw/otbn/crypto/rsa_keygen.s
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
.globl check_p
.globl check_q
.globl modinv_f4
.globl relprime_small_primes

/**
* Generate a random RSA key pair.
Expand Down Expand Up @@ -1003,8 +1004,7 @@ check_p:
/* Get the FG0.Z flag into a register.
x2 <= (CSRs[FG0] >> 3) & 1 = FG0.Z */
csrrs x2, FG0, x0
srli x2, x2, 3
andi x2, x2, 1
andi x2, x2, 8

/* If the flag is set, then the check failed and we can skip the remaining
checks. */
Expand Down Expand Up @@ -1033,13 +1033,6 @@ check_p:
li x2, 256
add x15, x14, x2

/**
* TODO: add something like BoringSSL's is_obviously_composite to filter out
* numbers that are divisible by the first few hundred primes. This filters
* out 80-90% of composites without resorting to the very slow Miller-Rabin
* check.
*/

/* Calculate the number of Miller-Rabin rounds. The number of rounds is
selected based on the bit-length according to FIPS 186-5, table B.1.
According to that table, the minimums for an error probability matching
Expand Down Expand Up @@ -1212,58 +1205,37 @@ generate_prime_candidate:
ret

/**
* Check if a large number is relatively prime to 65537 (aka F4).
*
* Returns a nonzero value if GCD(x,65537) == 1, and 0 otherwise
* Partially reduce a value modulo m such that 2^32 mod m == 1.
*
* A naive implementation would simply check if GCD(x, F4) == 1, However, we
* can simplify the check for relative primality using a few helpful facts
* about F4 specifically:
* 1. It is prime.
* 2. It has the special form (2^16 + 1).
* Returns r such that r mod m = x mod m and r < 2^35.
*
* Because F4 is prime, checking if a number x is relatively prime to F4 means
* simply checking if x is a direct multiple of F4; if (x % F4) != 0, then x is
* relatively prime to F4. This means that instead of computing GCD, we can use
* basic modular arithmetic.
* Can be used to speed up modular reduction on certain numbers.
*
* Here, the special form of F4, fact (2), comes in handy. Note that 2^16 is
* equivalent to -1 modulo F4. So if we express a number x in base-2^16, we can
* simplify as
* follows:
* x = x0 + 2^16 * x1 + 2^32 * x2 + 2^48 * x3 + ...
* x \equiv x0 + (-1) * x1 + (-1)^2 * x2 + (-1)^3 * x3 + ... (mod F4)
* x \equiv x0 - x1 + x2 - x3 + ... (mod F4)
*
* An additionally helpful observation based on fact (2) is that 2^32, 2^64,
* and in general 2^(32*k) for any k are all 1 modulo F4. This includes 2^256,
* so when we receive the input as a bignum in 256-bit limbs, we can simply
* all the limbs together to get an equivalent number modulo F4:
* Because we know 2^32 mod m is 1, it follows that in general 2^(32*k) for any
* k are all 1 modulo m. This includes 2^256, so when we receive the input as
* a bignum in 256-bit limbs, we can simply add all the limbs together to get an
* equivalent number modulo m:
* x = x[0] + 2^256 * x[1] + 2^512 * x[2] + ...
* x \equiv x[0] + x[1] + x[2] + ... (mod F4)
*
* From there, we can essentially use the same trick to bisect the number into
* 128-bit, 64-bit, and 32-bit chunks and add these together to get an
* equivalent number modulo F4. For the final 16-bit chunks, we need to
* subtract because 2^16 mod F4 = -1 rather than 1.
* equivalent number modulo m. This operation is visually sort of like folding
* the number over itself repeatedly, which is where the function gets its
* name.
*
* Flags: Flags have no meaning beyond the scope of this subroutine.
*
* @param[in] x16: dptr_x, pointer to first limb of x in dmem
* @param[in] x30: plen, number of 256-bit limbs for x
* @param[in] w24: constant, 2^256 - 1
* @param[in] w31: all-zero
* @param[out] w22: result, 0 only if x is not relatively prime to F4
* @param[out] w23: r, result
*
* clobbered registers: x2, w22, w23
* clobbered flag groups: FG0
*/
relprime_f4:
/* Load F4 into the modulus register for later.
MOD <= 2^16 + 1 */
bn.addi w22, w31, 1
bn.add w22, w22, w22 << 16
bn.wsrw MOD, w22

fold_bignum:
/* Initialize constants for loop. */
li x22, 22

Expand Down Expand Up @@ -1295,8 +1267,7 @@ relprime_f4:

/* Isolate the lower 128 bits of the sum.
w22 <= w23[127:0] */
bn.rshi w22, w23, w31 >> 128
bn.rshi w22, w31, w22 >> 128
bn.and w22, w23, w24 >> 128

/* Add the two 128-bit halves of the sum, plus the carry from the last round
of the sum computation. The sum is now up to 129 bits.
Expand All @@ -1305,8 +1276,7 @@ relprime_f4:

/* Isolate the lower 64 bits of the sum.
w22 <= w23[63:0] */
bn.rshi w22, w23, w31 >> 64
bn.rshi w22, w31, w22 >> 192
bn.and w22, w23, w24 >> 192

/* Add the two halves of the sum (technically 64 and 65 bits). A carry was
not possible in the previous addition since the value is too small. The
Expand All @@ -1316,29 +1286,76 @@ relprime_f4:

/* Isolate the lower 32 bits of the sum.
w22 <= w23[31:0] */
bn.rshi w22, w23, w31 >> 32
bn.rshi w22, w31, w22 >> 224
bn.and w22, w23, w24 >> 224

/* Add the two halves of the sum (technically 32 and 34 bits). A carry was
not possible in the previous addition since the value is too small.
w23 <= (w22 + (w23 >> 32)) */
bn.add w23, w22, w23 >> 32

/* Note: At this point, we're down to the last few terms:
x \equiv (w23[15:0] - w23[31:16] + w23[34:32]) mod F4 */
ret

/**
* Check if a large number is relatively prime to 65537 (aka F4).
*
* Returns a nonzero value if GCD(x,65537) == 1, and 0 otherwise
*
* A naive implementation would simply check if GCD(x, F4) == 1. However, we
* can simplify the check for relative primality using a few helpful facts
* about F4 specifically:
* 1. It is prime.
* 2. It has the special form (2^16 + 1).
*
* Because F4 is prime, checking if a number x is relatively prime to F4 means
* simply checking if x is a direct multiple of F4; if (x % F4) != 0, then x is
* relatively prime to F4. This means that instead of computing GCD, we can use
* basic modular arithmetic.
*
* Here, the special form of F4, fact (2), comes in handy. Since 2^32 mod F4 =
* 1, we can use `fold_bignum` to bring the number down to 35 bits cheaply.
*
* Since 2^16 is equivalent to -1 modulo F4, we can express the resulting
* number in base-2^16 and simplify as follows:
* x = x0 + 2^16 * x1 + 2^32 * x2
* x \equiv x0 + (-1) * x1 + (-1)^2 * x2
* x \equiv x0 - x1 + x2 (mod F4)
*
* Flags: Flags have no meaning beyond the scope of this subroutine.
*
* @param[in] x16: dptr_x, pointer to first limb of x in dmem
* @param[in] x30: n, number of 256-bit limbs for x
* @param[in] w31: all-zero
* @param[out] w22: result, 0 only if x is not relatively prime to F4
*
* clobbered registers: x2, w22, w23, w24
* clobbered flag groups: FG0
*/
relprime_f4:
/* Load F4 into the modulus register for later.
MOD <= 2^16 + 1 */
bn.addi w22, w31, 1
bn.add w22, w22, w22 << 16
bn.wsrw MOD, w22

/* Generate a 256-bit mask.
w24 <= 2^256 - 1 */
bn.not w24, w31

/* Fold the bignum to get a 35-bit number r such that r mod F4 = x mod F4.
w23 <= r */
jal x1, fold_bignum

/* Isolate the lower 16 bits of the 35-bit working sum.
w22 <= w23[15:0] */
bn.rshi w22, w23, w31 >> 16
bn.rshi w22, w31, w22 >> 240
bn.and w22, w23, w24 >> 240

/* Add the lower 16 bits of the sum to the highest 3 bits to get a 17-bit
result.
w22 <= w22 + (w23 >> 32) */
bn.add w22, w22, w23 >> 32

/* The sum from the previous addition is < 2 * F4, so a modular addition with
zero is sufficient to fully reduce.
/* The sum from the previous addition is <= 2^16 - 1 + 2^3 - 1 < 2 * F4, so a
modular addition with zero is sufficient to fully reduce.
w22 <= w22 mod F4 */
bn.addm w22, w22, w31

Expand All @@ -1353,6 +1370,155 @@ relprime_f4:

ret

/**
 * Check if a large number is divisible by a few small primes.
 *
 * Returns 0 if x is divisible by a small prime, 2^256 - 1 otherwise.
 *
 * In this implementation we specifically check the primes 3, 5, and 17, since
 * these values m have the property that (1 << 8) mod m is 1 (and therefore
 * (1 << 16) mod m and (1 << 32) mod m are 1 as well), so we can use
 * `fold_bignum` to check them very quickly. We use `fold_bignum` to get a
 * 35-bit result and then fold the number a few more times to get a 9-bit
 * result that is congruent to x modulo each of the three primes.
 *
 * This routine is constant-time relative to x if x is not divisible by any
 * small primes, but exits early if it finds that x is divisible by a small
 * prime.
 *
 * Flags: Flags have no meaning beyond the scope of this subroutine.
 *
 * @param[in]  x16: dptr_x, pointer to first limb of x in dmem
 * @param[in]  x30: n, number of 256-bit limbs for x
 * @param[in]  w31: all-zero
 * @param[out] w22: result, 0 if x is divisible by a small prime
 *
 * clobbered registers: x2, w22, w23, w24, w25
 * clobbered flag groups: FG0
 */
relprime_small_primes:
  /* Generate a 256-bit mask.
       w24 <= 2^256 - 1 */
  bn.not    w24, w31

  /* Fold the bignum to get a 35-bit number r such that r mod m = x mod m for
     every m with 2^32 mod m = 1 (which includes 3, 5, and 17).
       w23 <= r */
  jal       x1, fold_bignum

  /* Isolate the lower 16 bits of the 35-bit working sum.
       w22 <= w23[15:0] */
  bn.and    w22, w23, w24 >> 240

  /* Add the lower 16 bits to the higher 19 bits to get a 20-bit result.
     This preserves the residue because 2^16 mod m = 1.
       w23 <= w22 + (w23 >> 16) */
  bn.add    w23, w22, w23 >> 16

  /* Isolate the lower 8 bits of the 20-bit working sum.
       w22 <= w23[7:0] */
  bn.and    w22, w23, w24 >> 248

  /* Add the lower 8 bits to the higher 12 bits to get a 13-bit result.
     This preserves the residue because 2^8 mod m = 1.
       w23 <= w22 + (w23 >> 8) */
  bn.add    w23, w22, w23 >> 8

  /* Isolate the lower 8 bits of the 13-bit working sum.
       w22 <= w23[7:0] */
  bn.and    w22, w23, w24 >> 248

  /* Add the lower 8 bits to the higher 5 bits to get a 9-bit result.
       w23 <= w22 + (w23 >> 8) */
  bn.add    w23, w22, w23 >> 8

  /* Note: `is_zero_mod_small_prime` reads its input from w23 and makes its
     own working copy in w25, so w23 is reused unchanged for all three checks
     and the mask in w24 (no longer needed) is overwritten with the modulus. */

  /* Check the residue modulo 3.
       x2 <= if (w23 mod 3) == 0 then 8 else 0 */
  bn.addi   w24, w31, 3
  jal       x1, is_zero_mod_small_prime

  /* If x2 != 0, exit early. */
  bne       x2, x0, __relprime_small_primes_fail

  /* Check the residue modulo 5.
       x2 <= if (w23 mod 5) == 0 then 8 else 0 */
  bn.addi   w24, w31, 5
  jal       x1, is_zero_mod_small_prime

  /* If x2 != 0, exit early. */
  bne       x2, x0, __relprime_small_primes_fail

  /* Check the residue modulo 17.
       x2 <= if (w23 mod 17) == 0 then 8 else 0 */
  bn.addi   w24, w31, 17
  jal       x1, is_zero_mod_small_prime

  /* If x2 != 0, exit early. */
  bne       x2, x0, __relprime_small_primes_fail

  /* No small prime divisors found; return 2^256 - 1. */
  bn.not    w22, w31
  ret

__relprime_small_primes_fail:
  /* Small prime divisor found; return 0. */
  bn.sub    w22, w22, w22
  ret

/**
 * Reduce a 9-bit number modulo a small number with conditional subtractions.
 *
 * Returns r = 8 if x mod m = 0, otherwise r = 0.
 *
 * Helper function for `relprime_small_primes`.
 *
 * This function runs in constant time: it contains no branches and always
 * executes exactly 8 loop iterations.
 *
 * The input register w23 is left unmodified; the reduction is performed on a
 * working copy in w25.
 *
 * Flags: Flags have no meaning beyond the scope of this subroutine.
 *
 * @param[in]  w23: x, input, x < 2^9.
 * @param[in]  w24: m, modulus, 2 <= m < 2^249 (m << 7 must fit in 256 bits).
 * @param[in]  w31: all-zero
 * @param[out]  x2: result, 8 if x mod m = 0 and otherwise 0
 *
 * clobbered registers: x2, w22, w24, w25
 * clobbered flag groups: FG0
 */
is_zero_mod_small_prime:
  /* Copy input so that w23 is preserved for the caller.
       w25 <= x */
  bn.mov    w25, w23

  /* Since we know m is at least 2 bits, start with m << 7 as a value that will
     definitely be greater than x / 2 (x < 2^9 <= 2 * (m << 7)).
       w24 <= (w24 || w31) >> 249 = m << 7 */
  bn.rshi   w24, w24, w31 >> 249

  /* Repeatedly reduce using conditional subtractions.
     Loop invariant at the start of each iteration (i=7 to 0):
       w24 = m << i
       w25 < 2*(m << i)
       w25 mod m = x mod m
     After the final iteration, w25 = x mod m and w24 = m >> 1. */
  loopi     8, 3
    /* Trial subtraction; FG0.C is set if it underflowed.
         w22 <= w25 - w24 */
    bn.sub    w22, w25, w24
    /* Select the subtraction only if it did not underflow.
         w25 <= FG0.C ? w25 : w22 */
    bn.sel    w25, w25, w22, FG0.C
    /* Move on to the next smaller multiple of m.
         w24 <= w24 >> 1 */
    bn.rshi   w24, w31, w24 >> 1

  /* Check if w25 is 0.
       FG0.Z <= w25 == 0 */
  bn.cmp    w25, w31

  /* Get the FG0.Z flag (bit 3 of the FG0 CSR) into a register and return.
       x2 <= FG0 & 8 = FG0.Z << 3 */
  csrrs     x2, FG0, x0
  andi      x2, x2, 8

  ret

.section .scratchpad

/* Extra label marking the start of p || q in memory. The `derive_d` function
Expand Down
Loading

0 comments on commit 1ba1b59

Please sign in to comment.