Skip to content

Commit

Permalink
Cache random (#1640)
Browse files Browse the repository at this point in the history
* Avoid allocation in random()

This makes the function take its argument as a const generic argument,
which allows the allocation to be performed on the stack.

There are a few cases where we need to do in-place randomization of a
different type of object (`SmallVec` in a few places) so I've also
exposed an in-place mutation function.

Next step is to cache the slot that this uses.

* Cache randomness

* Add a test for the cache

* Remove dead code

---------

Co-authored-by: Lars Eggert <[email protected]>
  • Loading branch information
martinthomson and larseggert authored Feb 9, 2024
1 parent daa9394 commit bb74821
Show file tree
Hide file tree
Showing 16 changed files with 136 additions and 56 deletions.
21 changes: 15 additions & 6 deletions neqo-crypto/src/hkdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ use crate::{
},
err::{Error, Res},
p11::{
random, Item, PK11Origin, PK11SymKey, PK11_ImportDataKey, Slot, SymKey, CKA_DERIVE,
Item, PK11Origin, PK11SymKey, PK11_ImportDataKey, Slot, SymKey, CKA_DERIVE,
CKM_HKDF_DERIVE, CK_ATTRIBUTE_TYPE, CK_MECHANISM_TYPE,
},
random,
};

experimental_api!(SSL_HkdfExtract(
Expand All @@ -40,24 +41,32 @@ experimental_api!(SSL_HkdfExpandLabel(
secret: *mut *mut PK11SymKey,
));

fn key_size(version: Version, cipher: Cipher) -> Res<usize> {
const MAX_KEY_SIZE: usize = 48;
const fn key_size(version: Version, cipher: Cipher) -> Res<usize> {
if version != TLS_VERSION_1_3 {
return Err(Error::UnsupportedVersion);
}
Ok(match cipher {
let size = match cipher {
TLS_AES_128_GCM_SHA256 | TLS_CHACHA20_POLY1305_SHA256 => 32,
TLS_AES_256_GCM_SHA384 => 48,
_ => return Err(Error::UnsupportedCipher),
})
};
debug_assert!(size <= MAX_KEY_SIZE);
Ok(size)
}

/// Generate a random key of the right size for the given suite.
///
/// # Errors
///
/// Only if NSS fails.
/// If the ciphersuite or protocol version is not supported.
pub fn generate_key(version: Version, cipher: Cipher) -> Res<SymKey> {
import_key(version, &random(key_size(version, cipher)?))
// With generic_const_expr, this becomes:
// import_key(version, &random::<{ key_size(version, cipher) }>())
import_key(
version,
&random::<MAX_KEY_SIZE>()[0..key_size(version, cipher)?],
)
}

/// Import a symmetric key for use with HKDF.
Expand Down
2 changes: 1 addition & 1 deletion neqo-crypto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub use self::{
},
err::{Error, PRErrorCode, Res},
ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult},
p11::{random, PrivateKey, PublicKey, SymKey},
p11::{random, randomize, PrivateKey, PublicKey, SymKey},
replay::AntiReplay,
secrets::SecretDirection,
ssl::Opt,
Expand Down
99 changes: 88 additions & 11 deletions neqo-crypto/src/p11.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#![allow(non_snake_case)]

use std::{
cell::RefCell,
convert::TryFrom,
mem,
ops::{Deref, DerefMut},
Expand Down Expand Up @@ -289,31 +290,107 @@ impl Item {
}
}

/// Generate a randomized buffer.
/// Fill a buffer with randomness.
///
/// # Panics
///
/// When `size` is too large or NSS fails.
#[must_use]
pub fn random(size: usize) -> Vec<u8> {
let mut buf = vec![0; size];
secstatus_to_res(unsafe {
PK11_GenerateRandom(buf.as_mut_ptr(), c_int::try_from(buf.len()).unwrap())
})
.unwrap();
pub fn randomize<B: AsMut<[u8]>>(mut buf: B) -> B {
let m_buf = buf.as_mut();
let len = c_int::try_from(m_buf.len()).unwrap();
secstatus_to_res(unsafe { PK11_GenerateRandom(m_buf.as_mut_ptr(), len) }).unwrap();
buf
}

struct RandomCache {
cache: [u8; Self::SIZE],
used: usize,
}

impl RandomCache {
const SIZE: usize = 256;
const CUTOFF: usize = 32;

fn new() -> Self {
RandomCache {
cache: [0; Self::SIZE],
used: Self::SIZE,
}
}

fn randomize<B: AsMut<[u8]>>(&mut self, mut buf: B) -> B {
let m_buf = buf.as_mut();
debug_assert!(m_buf.len() <= Self::CUTOFF);
let avail = Self::SIZE - self.used;
if m_buf.len() <= avail {
m_buf.copy_from_slice(&self.cache[self.used..self.used + m_buf.len()]);
self.used += m_buf.len();
} else {
if avail > 0 {
m_buf[..avail].copy_from_slice(&self.cache[self.used..]);
}
randomize(&mut self.cache[..]);
self.used = m_buf.len() - avail;
m_buf[avail..].copy_from_slice(&self.cache[..self.used]);
}
buf
}
}

/// Generate a randomized array.
///
/// # Panics
///
/// When `size` is too large or NSS fails.
#[must_use]
pub fn random<const N: usize>() -> [u8; N] {
thread_local! { static CACHE: RefCell<RandomCache> = RefCell::new(RandomCache::new()) };

let buf = [0; N];
if N <= RandomCache::CUTOFF {
CACHE.with_borrow_mut(|c| c.randomize(buf))
} else {
randomize(buf)
}
}

#[cfg(test)]
mod test {
use test_fixture::fixture_init;

use super::random;
use super::RandomCache;
use crate::{random, randomize};

#[test]
fn randomness() {
fixture_init();
// If this ever fails, there is either a bug, or it's time to buy a lottery ticket.
assert_ne!(random(16), random(16));
// If any of these ever fail, there is either a bug, or it's time to buy a lottery ticket.
assert_ne!(random::<16>(), randomize([0; 16]));
assert_ne!([0; 16], random::<16>());
assert_ne!([0; 64], random::<64>());
}

#[test]
fn cache_random_lengths() {
const ZERO: [u8; 256] = [0; 256];

fixture_init();
let mut cache = RandomCache::new();
let mut buf = [0; 256];
let bits = usize::BITS - (RandomCache::CUTOFF - 1).leading_zeros();
let mask = 0xff >> (u8::BITS - bits);

for _ in 0..100 {
let len = loop {
let len = usize::from(random::<1>()[0] & mask) + 1;
if len <= RandomCache::CUTOFF {
break len;
}
};
buf.fill(0);
if len >= 16 {
assert_ne!(&cache.randomize(&mut buf[..len])[..len], &ZERO[..len]);
}
}
}
}
2 changes: 1 addition & 1 deletion neqo-crypto/src/selfencrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl SelfEncrypt {
// opaque aead_encrypted(plaintext)[length as expanded];
// };
// AAD covers the entire header, plus the value of the AAD parameter that is provided.
let salt = random(Self::SALT_LENGTH);
let salt = random::<{ Self::SALT_LENGTH }>();
let cipher = self.make_aead(&self.key, &salt)?;
let encoded_len = 2 + salt.len() + plaintext.len() + cipher.expansion();

Expand Down
7 changes: 2 additions & 5 deletions neqo-http3/src/frames/hframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,7 @@ impl HFrame {
Self::MaxPushId { .. } => H3_FRAME_TYPE_MAX_PUSH_ID,
Self::PriorityUpdateRequest { .. } => H3_FRAME_TYPE_PRIORITY_UPDATE_REQUEST,
Self::PriorityUpdatePush { .. } => H3_FRAME_TYPE_PRIORITY_UPDATE_PUSH,
Self::Grease => {
let r = random(7);
Decoder::from(&r).decode_uint(7).unwrap() * 0x1f + 0x21
}
Self::Grease => Decoder::from(&random::<7>()).decode_uint(7).unwrap() * 0x1f + 0x21,
}
}

Expand Down Expand Up @@ -120,7 +117,7 @@ impl HFrame {
}
Self::Grease => {
// Encode some number of random bytes.
let r = random(8);
let r = random::<8>();
enc.encode_vvec(&r[1..usize::from(1 + (r[0] & 0x7))]);
}
Self::PriorityUpdateRequest {
Expand Down
2 changes: 1 addition & 1 deletion neqo-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ impl HttpServer for SimpleServer {
fn enable_ech(&mut self) -> &[u8] {
let (sk, pk) = generate_ech_keys().expect("should create ECH keys");
self.server
.enable_ech(random(1)[0], "public.example", &sk, &pk)
.enable_ech(random::<1>()[0], "public.example", &sk, &pk)
.unwrap();
self.server.ech_config()
}
Expand Down
2 changes: 1 addition & 1 deletion neqo-server/src/old_https.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ impl HttpServer for Http09Server {
fn enable_ech(&mut self) -> &[u8] {
let (sk, pk) = generate_ech_keys().expect("generate ECH keys");
self.server
.enable_ech(random(1)[0], "public.example", &sk, &pk)
.enable_ech(random::<1>()[0], "public.example", &sk, &pk)
.expect("enable ECH");
self.server.ech_config()
}
Expand Down
30 changes: 14 additions & 16 deletions neqo-transport/src/cid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ use std::{
};

use neqo_common::{hex, hex_with_len, qinfo, Decoder, Encoder};
use neqo_crypto::random;
use smallvec::SmallVec;
use neqo_crypto::{random, randomize};
use smallvec::{smallvec, SmallVec};

use crate::{
frame::FRAME_TYPE_NEW_CONNECTION_ID, packet::PacketBuilder, recovery::RecoveryToken,
Expand All @@ -41,14 +41,16 @@ pub struct ConnectionId {
impl ConnectionId {
pub fn generate(len: usize) -> Self {
assert!(matches!(len, 0..=MAX_CONNECTION_ID_LEN));
Self::from(random(len))
let mut cid = smallvec![0; len];
randomize(&mut cid);
Self { cid }
}

// Apply a wee bit of greasing here in picking a length between 8 and 20 bytes long.
pub fn generate_initial() -> Self {
let v = random(1);
let v = random::<1>()[0];
// Bias selection toward picking 8 (>50% of the time).
let len: usize = max(8, 5 + (v[0] & (v[0] >> 4))).into();
let len: usize = max(8, 5 + (v & (v >> 4))).into();
Self::generate(len)
}

Expand All @@ -75,12 +77,6 @@ impl From<SmallVec<[u8; MAX_CONNECTION_ID_LEN]>> for ConnectionId {
}
}

impl From<Vec<u8>> for ConnectionId {
fn from(cid: Vec<u8>) -> Self {
Self::from(SmallVec::from(cid))
}
}

impl<T: AsRef<[u8]> + ?Sized> From<&T> for ConnectionId {
fn from(buf: &T) -> Self {
Self::from(SmallVec::from(buf.as_ref()))
Expand Down Expand Up @@ -222,7 +218,9 @@ impl ConnectionIdDecoder for RandomConnectionIdGenerator {

impl ConnectionIdGenerator for RandomConnectionIdGenerator {
fn generate_cid(&mut self) -> Option<ConnectionId> {
Some(ConnectionId::from(&random(self.len)))
let mut buf = smallvec![0; self.len];
randomize(&mut buf);
Some(ConnectionId::from(buf))
}

fn as_decoder(&self) -> &dyn ConnectionIdDecoder {
Expand Down Expand Up @@ -250,8 +248,8 @@ pub struct ConnectionIdEntry<SRT: Clone + PartialEq> {
impl ConnectionIdEntry<[u8; 16]> {
/// Create a random stateless reset token so that it is hard to guess the correct
/// value and reset the connection.
fn random_srt() -> [u8; 16] {
<[u8; 16]>::try_from(&random(16)[..]).unwrap()
pub fn random_srt() -> [u8; 16] {
random::<16>()
}

/// Create the first entry, which won't have a stateless reset token.
Expand Down Expand Up @@ -476,7 +474,7 @@ impl ConnectionIdManager {
.add_local(ConnectionIdEntry::new(self.next_seqno, cid.clone(), ()));
self.next_seqno += 1;

let srt = <[u8; 16]>::try_from(&random(16)[..]).unwrap();
let srt = ConnectionIdEntry::random_srt();
Ok((cid, srt))
} else {
Err(Error::ConnectionIdsExhausted)
Expand Down Expand Up @@ -565,7 +563,7 @@ impl ConnectionIdManager {
if let Some(cid) = maybe_cid {
assert_ne!(cid.len(), 0);
// TODO: generate the stateless reset tokens from the connection ID and a key.
let srt = <[u8; 16]>::try_from(&random(16)[..]).unwrap();
let srt = ConnectionIdEntry::random_srt();

let seqno = self.next_seqno;
self.next_seqno += 1;
Expand Down
4 changes: 2 additions & 2 deletions neqo-transport/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use neqo_common::{
qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, Role,
};
use neqo_crypto::{
agent::CertificateInfo, random, Agent, AntiReplay, AuthenticationStatus, Cipher, Client, Group,
agent::CertificateInfo, Agent, AntiReplay, AuthenticationStatus, Cipher, Client, Group,
HandshakeState, PrivateKey, PublicKey, ResumptionToken, SecretAgentInfo, SecretAgentPreInfo,
Server, ZeroRttChecker,
};
Expand Down Expand Up @@ -2405,7 +2405,7 @@ impl Connection {
} else {
// The other side didn't provide a stateless reset token.
// That's OK, they can try guessing this.
<[u8; 16]>::try_from(&random(16)[..]).unwrap()
ConnectionIdEntry::random_srt()
};
self.paths
.primary()
Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/connection/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl ConnectionIdDecoder for CountingConnectionIdGenerator {

impl ConnectionIdGenerator for CountingConnectionIdGenerator {
fn generate_cid(&mut self) -> Option<ConnectionId> {
let mut r = random(20);
let mut r = random::<20>();
r[0] = 8;
r[1] = u8::try_from(self.counter >> 24).unwrap();
r[2] = u8::try_from((self.counter >> 16) & 0xff).unwrap();
Expand Down
6 changes: 3 additions & 3 deletions neqo-transport/src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ impl PacketBuilder {
let mask = if quic_bit { PACKET_BIT_FIXED_QUIC } else { 0 }
| if self.is_long() { 0 } else { PACKET_BIT_SPIN };
let first = self.header.start;
self.encoder.as_mut()[first] ^= random(1)[0] & mask;
self.encoder.as_mut()[first] ^= random::<1>()[0] & mask;
}

/// For an Initial packet, encode the token.
Expand Down Expand Up @@ -424,7 +424,7 @@ impl PacketBuilder {
PACKET_BIT_LONG
| PACKET_BIT_FIXED_QUIC
| (PacketType::Retry.to_byte(version) << 4)
| (random(1)[0] & 0xf),
| (random::<1>()[0] & 0xf),
);
encoder.encode_uint(4, version.wire_version());
encoder.encode_vec(1, dcid);
Expand All @@ -448,7 +448,7 @@ impl PacketBuilder {
versions: &[Version],
) -> Vec<u8> {
let mut encoder = Encoder::default();
let mut grease = random(4);
let mut grease = random::<4>();
// This will not include the "QUIC bit" sometimes. Intentionally.
encoder.encode_byte(PACKET_BIT_LONG | (grease[3] & 0x7f));
encoder.encode(&[0; 4]); // Zero version == VN.
Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ impl Path {
// Send PATH_CHALLENGE.
if let ProbeState::ProbeNeeded { probe_count } = self.state {
qtrace!([self], "Initiating path challenge {}", probe_count);
let data = <[u8; 8]>::try_from(&random(8)[..]).unwrap();
let data = random::<8>();
builder.encode_varint(FRAME_TYPE_PATH_CHALLENGE);
builder.encode(&data);

Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/tparams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ impl TransportParameters {

/// Set version information.
pub fn set_versions(&mut self, role: Role, versions: &VersionConfig) {
let rbuf = random(4);
let rbuf = random::<4>();
let mut other = Vec::with_capacity(versions.all().len() + 1);
let mut dec = Decoder::new(&rbuf);
let grease = (dec.decode_uint(4).unwrap() as u32) & 0xf0f0_f0f0 | 0x0a0a_0a0a;
Expand Down
2 changes: 1 addition & 1 deletion test-fixture/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl ConnectionIdDecoder for CountingConnectionIdGenerator {

impl ConnectionIdGenerator for CountingConnectionIdGenerator {
fn generate_cid(&mut self) -> Option<ConnectionId> {
let mut r = random(20);
let mut r = random::<20>();
// Randomize length, but ensure that the connection ID is long
// enough to pass for an original destination connection ID.
r[0] = max(8, 5 + ((r[0] >> 4) & r[0]));
Expand Down
Loading

0 comments on commit bb74821

Please sign in to comment.