diff --git a/neqo-bin/benches/main.rs b/neqo-bin/benches/main.rs index fe3aba2714..6bb8b3161d 100644 --- a/neqo-bin/benches/main.rs +++ b/neqo-bin/benches/main.rs @@ -20,7 +20,7 @@ struct Benchmark { fn transfer(c: &mut Criterion) { neqo_common::log::init(Some(log::LevelFilter::Off)); - neqo_crypto::init_db(PathBuf::from_str("../test-fixture/db").unwrap()); + neqo_crypto::init_db(PathBuf::from_str("../test-fixture/db").unwrap()).unwrap(); let done_sender = spawn_server(); diff --git a/neqo-bin/src/bin/server.rs b/neqo-bin/src/bin/server.rs index 8d166c7487..e9b30261e4 100644 --- a/neqo-bin/src/bin/server.rs +++ b/neqo-bin/src/bin/server.rs @@ -7,7 +7,7 @@ use clap::Parser; #[tokio::main] -async fn main() -> Result<(), std::io::Error> { +async fn main() -> Result<(), neqo_bin::server::Error> { let args = neqo_bin::server::Args::parse(); neqo_bin::server::server(args).await diff --git a/neqo-bin/src/client/mod.rs b/neqo-bin/src/client/mod.rs index e0169e3f24..81721802e1 100644 --- a/neqo-bin/src/client/mod.rs +++ b/neqo-bin/src/client/mod.rs @@ -46,6 +46,13 @@ pub enum Error { IoError(io::Error), QlogError, TransportError(neqo_transport::Error), + CryptoError(neqo_crypto::Error), +} + +impl From for Error { + fn from(err: neqo_crypto::Error) -> Self { + Self::CryptoError(err) + } } impl From for Error { @@ -478,11 +485,11 @@ fn qlog_new(args: &Args, hostname: &str, cid: &ConnectionId) -> Res { pub async fn client(mut args: Args) -> Res<()> { neqo_common::log::init(Some(args.verbose.log_level_filter())); - init(); + init()?; args.update_for_tests(); - init(); + init()?; let urls_by_origin = args .urls diff --git a/neqo-bin/src/server/mod.rs b/neqo-bin/src/server/mod.rs index f89d6620de..38eb766f5f 100644 --- a/neqo-bin/src/server/mod.rs +++ b/neqo-bin/src/server/mod.rs @@ -52,6 +52,13 @@ pub enum Error { IoError(io::Error), QlogError, TransportError(neqo_transport::Error), + CryptoError(neqo_crypto::Error), +} + +impl From for Error { + fn from(err: neqo_crypto::Error) -> Self { + Self::CryptoError(err) + } } impl From for Error { @@ -87,6 +94,8 @@ impl Display for Error { impl std::error::Error for Error {} +type Res = Result; + #[derive(Debug, Parser)] #[command(author, version, about, long_about = None)] pub struct Args { @@ -551,7 +560,7 @@ impl ServersRunner { select(sockets_ready, timeout_ready).await.factor_first().0 } - async fn run(&mut self) -> Result<(), io::Error> { + async fn run(&mut self) -> Res<()> { loop { match self.ready().await? { Ready::Socket(inx) => loop { @@ -581,13 +590,13 @@ enum Ready { Timeout, } -pub async fn server(mut args: Args) -> Result<(), io::Error> { +pub async fn server(mut args: Args) -> Res<()> { const HQ_INTEROP: &str = "hq-interop"; neqo_common::log::init(Some(args.verbose.log_level_filter())); assert!(!args.key.is_empty(), "Need at least one key"); - init_db(args.db.clone()); + init_db(args.db.clone())?; if let Some(testcase) = args.shared.qns_test.as_ref() { if args.shared.quic_parameters.quic_version.is_empty() { diff --git a/neqo-crypto/src/lib.rs b/neqo-crypto/src/lib.rs index b82b225d40..2db985e8ee 100644 --- a/neqo-crypto/src/lib.rs +++ b/neqo-crypto/src/lib.rs @@ -90,7 +90,7 @@ impl Drop for NssLoaded { } } -static INITIALIZED: OnceLock = OnceLock::new(); +static INITIALIZED: OnceLock> = OnceLock::new(); fn already_initialized() -> bool { unsafe { nss::NSS_IsInitialized() != 0 } @@ -108,24 +108,24 @@ fn version_check() { /// Initialize NSS. This only executes the initialization routines once, so if there is any chance /// that /// -/// # Panics +/// # Errors /// /// When NSS initialization fails. -pub fn init() { +pub fn init() -> Res<()> { // Set time zero. time::init(); - _ = INITIALIZED.get_or_init(|| { + let res = INITIALIZED.get_or_init(|| { version_check(); if already_initialized() { - return NssLoaded::External; + return Ok(NssLoaded::External); } - secstatus_to_res(unsafe { nss::NSS_NoDB_Init(null()) }).expect("NSS_NoDB_Init failed"); - secstatus_to_res(unsafe { nss::NSS_SetDomesticPolicy() }) - .expect("NSS_SetDomesticPolicy failed"); + secstatus_to_res(unsafe { nss::NSS_NoDB_Init(null()) })?; + secstatus_to_res(unsafe { nss::NSS_SetDomesticPolicy() })?; - NssLoaded::NoDb + Ok(NssLoaded::NoDb) }); + res.as_ref().map(|_| ()).map_err(Clone::clone) } /// This enables SSLTRACE by calling a simple, harmless function to trigger its @@ -133,31 +133,32 @@ pub fn init() { /// global options are accessed. Reading an option is the least impact approach. /// This allows us to use SSLTRACE in all of our unit tests and programs. #[cfg(debug_assertions)] -fn enable_ssl_trace() { +fn enable_ssl_trace() -> Res<()> { let opt = ssl::Opt::Locking.as_int(); let mut v: ::std::os::raw::c_int = 0; secstatus_to_res(unsafe { ssl::SSL_OptionGetDefault(opt, &mut v) }) - .expect("SSL_OptionGetDefault failed"); } /// Initialize with a database. /// -/// # Panics +/// # Errors /// /// If NSS cannot be initialized. -pub fn init_db>(dir: P) { +pub fn init_db>(dir: P) -> Res<()> { time::init(); - _ = INITIALIZED.get_or_init(|| { + let res = INITIALIZED.get_or_init(|| { version_check(); if already_initialized() { - return NssLoaded::External; + return Ok(NssLoaded::External); } let path = dir.into(); - assert!(path.is_dir()); - let pathstr = path.to_str().expect("path converts to string").to_string(); - let dircstr = CString::new(pathstr).unwrap(); - let empty = CString::new("").unwrap(); + if !path.is_dir() { + return Err(Error::InternalError); + } + let pathstr = path.to_str().ok_or(Error::InternalError)?; + let dircstr = CString::new(pathstr)?; + let empty = CString::new("")?; secstatus_to_res(unsafe { nss::NSS_Initialize( dircstr.as_ptr(), @@ -166,21 +167,19 @@ pub fn init_db>(dir: P) { nss::SECMOD_DB.as_ptr().cast(), nss::NSS_INIT_READONLY, ) - }) - .expect("NSS_Initialize failed"); + })?; - secstatus_to_res(unsafe { nss::NSS_SetDomesticPolicy() }) - .expect("NSS_SetDomesticPolicy failed"); + secstatus_to_res(unsafe { nss::NSS_SetDomesticPolicy() })?; secstatus_to_res(unsafe { ssl::SSL_ConfigServerSessionIDCache(1024, 0, 0, dircstr.as_ptr()) - }) - .expect("SSL_ConfigServerSessionIDCache failed"); + })?; #[cfg(debug_assertions)] - enable_ssl_trace(); + enable_ssl_trace()?; - NssLoaded::Db + Ok(NssLoaded::Db) }); + res.as_ref().map(|_| ()).map_err(Clone::clone) } /// # Panics diff --git a/neqo-crypto/tests/init.rs b/neqo-crypto/tests/init.rs index 13218cc340..ee7d808e29 100644 --- a/neqo-crypto/tests/init.rs +++ b/neqo-crypto/tests/init.rs @@ -15,13 +15,7 @@ use neqo_crypto::{assert_initialized, init_db}; // Pull in the NSS internals so that we can ask NSS if it thinks that // it is properly initialized. -#[allow( - dead_code, - non_upper_case_globals, - clippy::redundant_static_lifetimes, - clippy::unseparated_literal_suffix, - clippy::upper_case_acronyms -)] +#[allow(dead_code, non_upper_case_globals)] mod nss { include!(concat!(env!("OUT_DIR"), "/nss_init.rs")); } @@ -29,19 +23,54 @@ mod nss { #[cfg(nss_nodb)] #[test] fn init_nodb() { - init(); + neqo_crypto::init().unwrap(); assert_initialized(); unsafe { - assert!(nss::NSS_IsInitialized() != 0); + assert_ne!(nss::NSS_IsInitialized(), 0); } } +#[cfg(nss_nodb)] +#[test] +fn init_twice_nodb() { + unsafe { + nss::NSS_NoDB_Init(std::ptr::null()); + assert_ne!(nss::NSS_IsInitialized(), 0); + } + // Now do it again + init_nodb(); +} + #[cfg(not(nss_nodb))] #[test] fn init_withdb() { - init_db(::test_fixture::NSS_DB_PATH); + init_db(::test_fixture::NSS_DB_PATH).unwrap(); assert_initialized(); unsafe { - assert!(nss::NSS_IsInitialized() != 0); + assert_ne!(nss::NSS_IsInitialized(), 0); + } +} + +#[cfg(not(nss_nodb))] +#[test] +fn init_twice_withdb() { + use std::{ffi::CString, path::PathBuf}; + + let empty = CString::new("").unwrap(); + let path: PathBuf = ::test_fixture::NSS_DB_PATH.into(); + assert!(path.is_dir()); + let pathstr = path.to_str().unwrap(); + let dircstr = CString::new(pathstr).unwrap(); + unsafe { + nss::NSS_Initialize( + dircstr.as_ptr(), + empty.as_ptr(), + empty.as_ptr(), + nss::SECMOD_DB.as_ptr().cast(), + nss::NSS_INIT_READONLY, + ); + assert_ne!(nss::NSS_IsInitialized(), 0); } + // Now do it again + init_withdb(); } diff --git a/neqo-crypto/tests/selfencrypt.rs b/neqo-crypto/tests/selfencrypt.rs index b20aa27ee6..9fc2162fe2 100644 --- a/neqo-crypto/tests/selfencrypt.rs +++ b/neqo-crypto/tests/selfencrypt.rs @@ -15,7 +15,7 @@ use neqo_crypto::{ #[test] fn se_create() { - init(); + init().unwrap(); SelfEncrypt::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256).expect("constructor works"); } @@ -23,7 +23,7 @@ const PLAINTEXT: &[u8] = b"PLAINTEXT"; const AAD: &[u8] = b"AAD"; fn sealed() -> (SelfEncrypt, Vec) { - init(); + init().unwrap(); let se = SelfEncrypt::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256).unwrap(); let sealed = se.seal(AAD, PLAINTEXT).expect("sealing works"); (se, sealed) diff --git a/test-fixture/src/lib.rs b/test-fixture/src/lib.rs index a6043cd974..e34fb522ff 100644 --- a/test-fixture/src/lib.rs +++ b/test-fixture/src/lib.rs @@ -41,8 +41,12 @@ pub const NSS_DB_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/db"); /// Initialize the test fixture. Only call this if you aren't also calling a /// fixture function that depends on setup. Other functions in the fixture /// that depend on this setup call the function for you. +/// +/// # Panics +/// +/// When the NSS initialization fails. pub fn fixture_init() { - init_db(NSS_DB_PATH); + init_db(NSS_DB_PATH).unwrap(); } // This needs to be > 2ms to avoid it being rounded to zero.