Skip to content

Commit

Permalink
Less unwrap
Browse files Browse the repository at this point in the history
  • Loading branch information
larseggert committed Dec 19, 2024
1 parent af99cbc commit cf9783c
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 69 deletions.
27 changes: 14 additions & 13 deletions neqo-crypto/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ macro_rules! preinfo_arg {
pub fn $v(&self) -> Option<$t> {
match self.info.valuesSet & ssl::$m {
0 => None,
_ => Some(<$t>::try_from(self.info.$f).unwrap()),
_ => Some(<$t>::try_from(self.info.$f).ok()?),
}
}
};
Expand Down Expand Up @@ -158,12 +158,11 @@ impl SecretAgentPreInfo {
self.info.canSendEarlyData != 0
}

/// # Panics
/// # Errors
///
/// If `usize` is less than 32 bits and the value is too large.
#[must_use]
pub fn max_early_data(&self) -> usize {
usize::try_from(self.info.maxEarlyDataSize).unwrap()
pub fn max_early_data(&self) -> Res<usize> {
usize::try_from(self.info.maxEarlyDataSize).map_err(|_| Error::InternalError)
}

/// Was ECH accepted.
Expand Down Expand Up @@ -872,12 +871,10 @@ impl Client {
arg: *mut c_void,
) -> ssl::SECStatus {
let mut info: MaybeUninit<ssl::SSLResumptionTokenInfo> = MaybeUninit::uninit();
let info_res = &ssl::SSL_GetResumptionTokenInfo(
token,
len,
info.as_mut_ptr(),
c_uint::try_from(mem::size_of::<ssl::SSLResumptionTokenInfo>()).unwrap(),
);
let Ok(info_len) = c_uint::try_from(mem::size_of::<ssl::SSLResumptionTokenInfo>()) else {
return ssl::SECFailure;
};
let info_res = &ssl::SSL_GetResumptionTokenInfo(token, len, info.as_mut_ptr(), info_len);
if info_res.is_err() {
// Ignore the token.
return ssl::SECSuccess;
Expand All @@ -887,8 +884,12 @@ impl Client {
// Ignore the token.
return ssl::SECSuccess;
}
let resumption = arg.cast::<Vec<ResumptionToken>>().as_mut().unwrap();
let len = usize::try_from(len).unwrap();
let Some(resumption) = arg.cast::<Vec<ResumptionToken>>().as_mut() else {
return ssl::SECFailure;
};
let Ok(len) = usize::try_from(len) else {
return ssl::SECFailure;
};
let mut v = Vec::with_capacity(len);
v.extend_from_slice(null_safe_slice(token, len));
qdebug!(
Expand Down
22 changes: 15 additions & 7 deletions neqo-crypto/src/agentio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,15 @@ impl RecordList {
len: c_uint,
arg: *mut c_void,
) -> ssl::SECStatus {
let records = arg.cast::<Self>().as_mut().unwrap();
let Some(records) = arg.cast::<Self>().as_mut() else {
return ssl::SECFailure;
};

let slice = null_safe_slice(data, len);
records.append(epoch, ContentType::try_from(ct).unwrap(), slice);
let Ok(ct) = ContentType::try_from(ct) else {
return ssl::SECFailure;
};
records.append(epoch, ct, slice);
ssl::SECSuccess
}

Expand Down Expand Up @@ -331,7 +336,9 @@ unsafe extern "C" fn agent_available64(mut fd: PrFd) -> prio::PRInt64 {

#[allow(clippy::cast_possible_truncation)]
unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus {
let a = addr.as_mut().unwrap();
let Some(a) = addr.as_mut() else {
return PR_FAILURE;
};
// Cast is safe because prio::PR_AF_INET is 2
a.inet.family = prio::PR_AF_INET as prio::PRUint16;
a.inet.port = 0;
Expand All @@ -340,10 +347,11 @@ unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrS
}

unsafe extern "C" fn agent_getsockopt(_fd: PrFd, opt: *mut prio::PRSocketOptionData) -> PrStatus {
let o = opt.as_mut().unwrap();
if o.option == prio::PRSockOption::PR_SockOpt_Nonblocking {
o.value.non_blocking = 1;
return PR_SUCCESS;
if let Some(o) = opt.as_mut() {
if o.option == prio::PRSockOption::PR_SockOpt_Nonblocking {
o.value.non_blocking = 1;
return PR_SUCCESS;
}
}
PR_FAILURE
}
Expand Down
2 changes: 1 addition & 1 deletion neqo-crypto/src/ech.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ pub fn generate_keys() -> Res<(PrivateKey, PublicKey)> {
params.extend_from_slice(oid_slc);

let mut public_ptr: *mut SECKEYPublicKey = null_mut();
let mut param_item = Item::wrap(&params);
let mut param_item = Item::wrap(&params)?;

// If we have tracing on, try to ensure that key data can be read.
let insensitive_secret_ptr = if log::log_enabled!(log::Level::Trace) {
Expand Down
2 changes: 1 addition & 1 deletion neqo-crypto/src/hkdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ pub fn import_key(version: Version, buf: &[u8]) -> Res<SymKey> {
CK_MECHANISM_TYPE::from(CKM_HKDF_DERIVE),
PK11Origin::PK11_OriginUnwrap,
CK_ATTRIBUTE_TYPE::from(CKA_DERIVE),
&mut Item::wrap(buf),
&mut Item::wrap(buf)?,
null_mut(),
)
};
Expand Down
4 changes: 2 additions & 2 deletions neqo-crypto/src/hp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl HpKey {
mech,
CK_ATTRIBUTE_TYPE::from(CKA_ENCRYPT),
*key,
&Item::wrap(&ZERO[..0]), // Borrow a zero-length slice of ZERO.
&Item::wrap(&ZERO[..0])?, // Borrow a zero-length slice of ZERO.
)
};
let context = Context::from_ptr(context_ptr).or(Err(Error::CipherInitFailure))?;
Expand Down Expand Up @@ -169,7 +169,7 @@ impl HpKey {
ulNonceBits: 96,
};
let mut output_len: c_uint = 0;
let mut param_item = Item::wrap_struct(&params);
let mut param_item = Item::wrap_struct(&params)?;
secstatus_to_res(unsafe {
PK11_Encrypt(
**key,
Expand Down
32 changes: 10 additions & 22 deletions neqo-crypto/src/p11.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,25 +248,27 @@ impl Item {
/// Creating this object is technically safe, but using it is extremely dangerous.
/// Minimally, it can only be passed as a `const SECItem*` argument to functions,
/// or those that treat their argument as `const`.
pub fn wrap(buf: &[u8]) -> SECItem {
SECItem {
pub fn wrap(buf: &[u8]) -> Res<SECItem> {
let len = c_uint::try_from(buf.len())?;
Ok(SECItem {
type_: SECItemType::siBuffer,
data: buf.as_ptr().cast_mut(),
len: c_uint::try_from(buf.len()).unwrap(),
}
len,
})
}

/// Create a wrapper for a struct.
/// Creating this object is technically safe, but using it is extremely dangerous.
/// Minimally, it can only be passed as a `const SECItem*` argument to functions,
/// or those that treat their argument as `const`.
pub fn wrap_struct<T>(v: &T) -> SECItem {
pub fn wrap_struct<T>(v: &T) -> Res<SECItem> {
let data: *const T = v;
SECItem {
let len = c_uint::try_from(mem::size_of::<T>())?;
Ok(SECItem {
type_: SECItemType::siBuffer,
data: data.cast_mut().cast(),
len: c_uint::try_from(mem::size_of::<T>()).unwrap(),
}
len,
})
}

/// Make an empty `SECItem` for passing as a mutable `SECItem*` argument.
Expand All @@ -277,20 +279,6 @@ impl Item {
len: 0,
}
}

/// This dereferences the pointer held by the item and makes a copy of the
/// content that is referenced there.
///
/// # Safety
///
/// This dereferences two pointers. It doesn't get much less safe.
pub unsafe fn into_vec(self) -> Vec<u8> {
let b = self.ptr.as_ref().unwrap();
// Sanity check the type, as some types don't count bytes in `Item::len`.
assert_eq!(b.type_, SECItemType::siBuffer);
let slc = null_safe_slice(b.data, b.len);
Vec::from(slc)
}
}

unsafe fn destroy_secitem_array(array: *mut SECItemArray) {
Expand Down
4 changes: 2 additions & 2 deletions neqo-crypto/tests/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ fn check_client_preinfo(client_preinfo: &SecretAgentPreInfo) {
assert_eq!(client_preinfo.cipher_suite(), None);
assert!(!client_preinfo.early_data());
assert_eq!(client_preinfo.early_data_cipher(), None);
assert_eq!(client_preinfo.max_early_data(), 0);
assert_eq!(client_preinfo.max_early_data(), Ok(0));
assert_eq!(client_preinfo.alpn(), None);
}

Expand All @@ -87,7 +87,7 @@ fn check_server_preinfo(server_preinfo: &SecretAgentPreInfo) {
assert_eq!(server_preinfo.cipher_suite(), Some(TLS_AES_128_GCM_SHA256));
assert!(!server_preinfo.early_data());
assert_eq!(server_preinfo.early_data_cipher(), None);
assert_eq!(server_preinfo.max_early_data(), 0);
assert_eq!(server_preinfo.max_early_data(), Ok(0));
assert_eq!(server_preinfo.alpn(), None);
}

Expand Down
2 changes: 1 addition & 1 deletion neqo-http3/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,7 @@ impl Http3Connection {
self.send_streams
.remove(&stream_id)
.ok_or(Error::Internal)?,
)));
)?));
self.add_streams(
stream_id,
Box::new(extended_conn.clone()),
Expand Down
40 changes: 20 additions & 20 deletions neqo-http3/src/features/extended_connect/webtransport_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,24 +99,23 @@ impl WebTransportSession {
///
/// This function is only called with `RecvStream` and `SendStream` that also implement
/// the http specific functions and `http_stream()` will never return `None`.
#[must_use]
pub fn new_with_http_streams(
session_id: StreamId,
events: Box<dyn ExtendedConnectEvents>,
role: Role,
mut control_stream_recv: Box<dyn RecvStream>,
mut control_stream_send: Box<dyn SendStream>,
) -> Self {
) -> Res<Self> {
let stream_event_listener = Rc::new(RefCell::new(WebTransportSessionListener::default()));
control_stream_recv
.http_stream()
.unwrap()
.ok_or(Error::Internal)?
.set_new_listener(Box::new(stream_event_listener.clone()));
control_stream_send
.http_stream()
.unwrap()
.ok_or(Error::Internal)?
.set_new_listener(Box::new(stream_event_listener.clone()));
Self {
Ok(Self {
control_stream_recv,
control_stream_send,
stream_event_listener,
Expand All @@ -127,7 +126,7 @@ impl WebTransportSession {
send_streams: BTreeSet::new(),
recv_streams: BTreeSet::new(),
role,
}
})
}

/// # Errors
Expand All @@ -149,7 +148,7 @@ impl WebTransportSession {
qtrace!([self], "receive control data");
let (out, _) = self.control_stream_recv.receive(conn)?;
debug_assert!(out == ReceiveOutput::NoOutput);
self.maybe_check_headers();
self.maybe_check_headers()?;
self.read_control_stream(conn)?;
Ok((ReceiveOutput::NoOutput, self.state == SessionState::Done))
}
Expand All @@ -161,16 +160,16 @@ impl WebTransportSession {
.ok_or(Error::Internal)?
.header_unblocked(conn)?;
debug_assert!(out == ReceiveOutput::NoOutput);
self.maybe_check_headers();
self.maybe_check_headers()?;
self.read_control_stream(conn)?;
Ok((ReceiveOutput::NoOutput, self.state == SessionState::Done))
}

fn maybe_update_priority(&mut self, priority: Priority) -> bool {
self.control_stream_recv
.http_stream()
.unwrap()
.maybe_update_priority(priority)
let Some(stream) = self.control_stream_recv.http_stream() else {
return false;
};
stream.maybe_update_priority(priority)
}

fn priority_update_frame(&mut self) -> Option<HFrame> {
Expand All @@ -180,10 +179,10 @@ impl WebTransportSession {
}

fn priority_update_sent(&mut self) {
self.control_stream_recv
.http_stream()
.unwrap()
.priority_update_sent();
let Some(stream) = self.control_stream_recv.http_stream() else {
return;
};
stream.priority_update_sent();
}

fn send(&mut self, conn: &mut Connection) -> Res<()> {
Expand Down Expand Up @@ -221,9 +220,9 @@ impl WebTransportSession {
/// # Panics
///
/// This cannot panic because headers are checked before this function called.
pub fn maybe_check_headers(&mut self) {
pub fn maybe_check_headers(&mut self) -> Res<()> {
if SessionState::Negotiating != self.state {
return;
return Ok(());
}

if let Some((headers, interim, fin)) = self.stream_event_listener.borrow_mut().get_headers()
Expand Down Expand Up @@ -257,7 +256,7 @@ impl WebTransportSession {
None
}
})
.unwrap();
.ok_or(Error::Internal)?;

self.state = if (200..300).contains(&status) {
if fin {
Expand Down Expand Up @@ -290,7 +289,8 @@ impl WebTransportSession {
SessionState::Done
};
}
}
};
Ok(())
}

pub fn add_stream(&mut self, stream_id: StreamId) {
Expand Down

0 comments on commit cf9783c

Please sign in to comment.