diff --git a/dsa/src/generate/secret_number.rs b/dsa/src/generate/secret_number.rs index 2009dba1..218bb574 100644 --- a/dsa/src/generate/secret_number.rs +++ b/dsa/src/generate/secret_number.rs @@ -3,30 +3,14 @@ //! use crate::{Components, SigningKey}; -use alloc::{vec, vec::Vec}; -use core::cmp::min; -use crypto_bigint::{BoxedUint, InvMod, NonZero, RandomBits}; +use alloc::vec; +use crypto_bigint::{BoxedUint, Integer, InvMod, NonZero, RandomBits}; use digest::{core_api::BlockSizeUser, Digest, FixedOutputReset}; -use rfc6979::HmacDrbg; use signature::rand_core::CryptoRngCore; -use zeroize::Zeroize; +use zeroize::Zeroizing; -/// Reduce the hash into an RFC-6979 appropriate form -fn reduce_hash(q: &NonZero, hash: &[u8]) -> Vec { - // Reduce the hash modulo Q - let q_byte_len = q.bits() / 8; - - let hash_len = min(hash.len(), q_byte_len as usize); - let hash = &hash[..hash_len]; - - let hash = BoxedUint::from_be_slice(hash, (hash.len() * 8) as u32).unwrap(); - let mut reduced = Vec::from((hash % q).to_be_bytes()); - - while reduced.len() < q_byte_len as usize { - reduced.insert(0, 0); - } - - reduced +fn strip_leading_zeros(buffer: &[u8], desired_size: usize) -> &[u8] { + &buffer[(buffer.len() - desired_size)..] } /// Generate a per-message secret number k deterministically using the method described in RFC 6979 @@ -40,20 +24,27 @@ where D: Digest + BlockSizeUser + FixedOutputReset, { let q = signing_key.verifying_key().components().q(); - let k_size = (q.bits() / 8) as usize; - let hash = reduce_hash(q, hash); + let size = (q.bits() / 8) as usize; + + // Reduce hash mod q + let hash = BoxedUint::from_be_slice(hash, (hash.len() * 8) as u32).unwrap(); + let hash = (hash % q).to_be_bytes(); + let hash = strip_leading_zeros(&hash, size); + + let q_bytes = q.to_be_bytes(); + let q_bytes = strip_leading_zeros(&q_bytes, size); - let mut x_bytes = signing_key.x().to_be_bytes(); - let mut hmac = HmacDrbg::::new(&x_bytes, &hash, &[]); - x_bytes.zeroize(); + let x_bytes = Zeroizing::new(signing_key.x().to_be_bytes()); + let x_bytes = strip_leading_zeros(&x_bytes, size); - let mut buffer = vec![0; k_size]; + let mut buffer = vec![0; size]; loop { - hmac.fill_bytes(&mut buffer); + rfc6979::generate_k_mut::(x_bytes, q_bytes, hash, &[], &mut buffer); let k = BoxedUint::from_be_slice(&buffer, (buffer.len() * 8) as u32).unwrap(); if let Some(inv_k) = k.inv_mod(q).into() { - if k > BoxedUint::zero() && k < **q { + if (bool::from(k.is_nonzero())) & (k < **q) { + debug_assert!(bool::from(k.is_odd())); return (k, inv_k); } }