Skip to content

Commit

Permalink
feat(s2n-quic-dc): add ClientConfirm subscriber (#2274)
Browse files Browse the repository at this point in the history
* feat(s2n-quic-dc): add ClientConfirm subscriber

* PR feedback

* handle server case

* clippy
  • Loading branch information
WesleyRosenblum authored Jul 19, 2024
1 parent fcd9a1b commit 59564ff
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 26 deletions.
4 changes: 4 additions & 0 deletions quic/s2n-quic/src/provider/dc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@

//! Provides dc support
mod confirm;

use s2n_quic_core::dc::Disabled;

// these imports are only accessible if the unstable feature is enabled
#[allow(unused_imports)]
pub use confirm::ConfirmComplete;
#[allow(unused_imports)]
pub use s2n_quic_core::dc::{ApplicationParams, ConnectionInfo, Endpoint, Path};

pub trait Provider {
Expand Down
144 changes: 144 additions & 0 deletions quic/s2n-quic/src/provider/dc/confirm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::Connection;
use core::task::{Context, Poll, Waker};
use s2n_quic_core::{
connection,
connection::Error,
ensure,
event::{
api as events,
api::{ConnectionInfo, ConnectionMeta, DcState, EndpointType, Subscriber},
},
};
use std::io;

/// `event::Subscriber` used for ensuring an s2n-quic client or server negotiating dc
/// waits for the dc handshake to complete
pub struct ConfirmComplete;
impl ConfirmComplete {
/// Blocks the task until the provided connection has either completed the dc handshake or closed
/// with an error
pub async fn wait_ready(conn: &mut Connection) -> io::Result<()> {
core::future::poll_fn(|cx| {
conn.query_event_context_mut(|context: &mut ConfirmContext| context.poll_ready(cx))
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
})
.await
}
}

#[derive(Default)]
pub struct ConfirmContext {
waker: Option<Waker>,
state: State,
}

impl ConfirmContext {
/// Updates the state on the context
fn update(&mut self, state: State) {
self.state = state;

// notify the application that the state was updated
self.wake();
}

/// Polls the context for handshake completion
fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
match self.state {
// if we're ready or have errored then let the application know
State::Ready => Poll::Ready(Ok(())),
State::Failed(error) => Poll::Ready(Err(error.into())),
State::Waiting(_) => {
// store the waker so we can notify the application of state updates
self.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}

/// notify the application of a state update
fn wake(&mut self) {
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
}

impl Drop for ConfirmContext {
// make sure the application is notified that we're closing the connection
fn drop(&mut self) {
if matches!(self.state, State::Waiting(_)) {
self.state = State::Failed(connection::Error::unspecified());
}
self.wake();
}
}

enum State {
Waiting(Option<DcState>),
Ready,
Failed(connection::Error),
}

impl Default for State {
fn default() -> Self {
State::Waiting(None)
}
}

impl Subscriber for ConfirmComplete {
type ConnectionContext = ConfirmContext;

#[inline]
fn create_connection_context(
&mut self,
_: &ConnectionMeta,
_info: &ConnectionInfo,
) -> Self::ConnectionContext {
ConfirmContext::default()
}

#[inline]
fn on_connection_closed(
&mut self,
context: &mut Self::ConnectionContext,
meta: &ConnectionMeta,
event: &events::ConnectionClosed,
) {
ensure!(matches!(context.state, State::Waiting(_)));

match (&meta.endpoint_type, event.error, &context.state) {
(
EndpointType::Server { .. },
Error::Closed { .. },
State::Waiting(Some(DcState::PathSecretsReady { .. })),
) => {
// The client may close the connection immediately after the dc handshake completes,
// before it sends acknowledgement of the server's DC_STATELESS_RESET_TOKENS.
// Since the server has already moved into the PathSecretsReady state, this can be considered
// as a successful completion of the dc handshake.
context.update(State::Ready)
}
_ => context.update(State::Failed(event.error)),
}
}

#[inline]
fn on_dc_state_changed(
&mut self,
context: &mut Self::ConnectionContext,
_meta: &ConnectionMeta,
event: &events::DcStateChanged,
) {
ensure!(matches!(context.state, State::Waiting(_)));

if let DcState::Complete { .. } = event.state {
// notify the application that the dc handshake has completed
context.update(State::Ready);
} else {
context.update(State::Waiting(Some(event.state.clone())));
}
}
}
86 changes: 60 additions & 26 deletions quic/s2n-quic/src/tests/dc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
// SPDX-License-Identifier: Apache-2.0

use super::*;
use crate::{client, client::ClientProviders, server, server::ServerProviders};
use crate::{client, client::ClientProviders, provider::dc, server, server::ServerProviders};
use s2n_quic_core::{
dc::testing::MockDcEndpoint,
event::{api::DcState, Timestamp},
stateless_reset::token::testing::{TEST_TOKEN_1, TEST_TOKEN_2},
};
use std::io::ErrorKind;

// Client Server
//
Expand Down Expand Up @@ -42,7 +43,7 @@ fn dc_handshake_self_test() {
let server = Server::builder().with_tls(SERVER_CERTS).unwrap();
let client = Client::builder().with_tls(certificates::CERT_PEM).unwrap();

self_test(server, client);
self_test(server, client, None);
}

// Client Server
Expand Down Expand Up @@ -81,17 +82,28 @@ fn dc_mtls_handshake_self_test() {
let client_tls = build_client_mtls_provider(certificates::MTLS_CA_CERT).unwrap();
let client = Client::builder().with_tls(client_tls).unwrap();

self_test(server, client);
self_test(server, client, None);
}

#[test]
fn dc_mtls_handshake_auth_failure_self_test() {
let server_tls = build_server_mtls_provider(certificates::UNTRUSTED_CERT_PEM).unwrap();
let server = Server::builder().with_tls(server_tls).unwrap();

let client_tls = build_client_mtls_provider(certificates::MTLS_CA_CERT).unwrap();
let client = Client::builder().with_tls(client_tls).unwrap();

self_test(server, client, Some(ErrorKind::ConnectionReset));
}

fn self_test<S: ServerProviders, C: ClientProviders>(
server: server::Builder<S>,
client: client::Builder<C>,
expected_error: Option<ErrorKind>,
) {
let model = Model::default();
let rtt = Duration::from_millis(100);
model.set_delay(rtt / 2);
const LEN: usize = 1000;

let server_subscriber = DcStateChanged::new();
let server_events = server_subscriber.clone();
Expand All @@ -103,60 +115,68 @@ fn self_test<S: ServerProviders, C: ClientProviders>(
test(model, |handle| {
let mut server = server
.with_io(handle.builder().build()?)?
.with_event((tracing_events(), server_subscriber))?
.with_event((dc::ConfirmComplete, (tracing_events(), server_subscriber)))?
.with_random(Random::with_seed(456))?
.with_dc(MockDcEndpoint::new(&server_tokens))?
.start()?;

let addr = server.local_addr()?;
spawn(async move {
let mut conn = server.accept().await.unwrap();
let mut stream = conn.open_bidirectional_stream().await.unwrap();
stream.send(vec![42; LEN].into()).await.unwrap();
stream.flush().await.unwrap();
let conn = server.accept().await;
if expected_error.is_some() {
assert!(conn.is_none());
} else {
assert!(dc::ConfirmComplete::wait_ready(&mut conn.unwrap())
.await
.is_ok());
}
});

let client = client
.with_io(handle.builder().build().unwrap())?
.with_event((tracing_events(), client_subscriber))?
.with_event((dc::ConfirmComplete, (tracing_events(), client_subscriber)))?
.with_random(Random::with_seed(456))?
.with_dc(MockDcEndpoint::new(&client_tokens))?
.start()?;

let client_events = client_events.clone();

primary::spawn(async move {
let connect = Connect::new(addr).with_server_name("localhost");
let mut conn = client.connect(connect).await.unwrap();
let mut stream = conn.accept_bidirectional_stream().await.unwrap().unwrap();
let result = dc::ConfirmComplete::wait_ready(&mut conn).await;

let mut recv_len = 0;
while let Some(chunk) = stream.receive().await.unwrap() {
recv_len += chunk.len();
if let Some(error) = expected_error {
assert_eq!(error, result.err().unwrap().kind());
} else {
assert!(result.is_ok());
let client_events = client_events.events().lock().unwrap().clone();
assert_dc_complete(&client_events);
// wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent
// before the client closes the connection. This is only necessary to confirm the `dc::State`
// on the server moves to `DcState::Complete`
delay(Duration::from_millis(100)).await;
}
assert_eq!(LEN, recv_len);
});

Ok(addr)
})
.unwrap();

if expected_error.is_some() {
return;
}

let server_events = server_events.events().lock().unwrap().clone();
let client_events = client_events.events().lock().unwrap().clone();

assert_dc_complete(&server_events);
assert_dc_complete(&client_events);

// 3 state transitions (VersionNegotiated -> PathSecretsReady -> Complete)
assert_eq!(3, server_events.len());
assert_eq!(3, client_events.len());

for events in [server_events.clone(), client_events.clone()] {
if let DcState::VersionNegotiated { version, .. } = events[0].state {
assert_eq!(version, s2n_quic_core::dc::SUPPORTED_VERSIONS[0]);
} else {
panic!("VersionNegotiated should be the first dc state");
}

assert!(matches!(events[1].state, DcState::PathSecretsReady { .. }));
assert!(matches!(events[2].state, DcState::Complete { .. }));
}

// Server path secrets are ready in 1.5 RTTs measured from the start of the test, since it takes
// .5 RTT for the Initial from the client to reach the server
assert_eq!(
Expand All @@ -175,6 +195,20 @@ fn self_test<S: ServerProviders, C: ClientProviders>(
assert_eq!(rtt * 2, client_events[2].timestamp.duration_since_start());
}

fn assert_dc_complete(events: &[DcStateChangedEvent]) {
// 3 state transitions (VersionNegotiated -> PathSecretsReady -> Complete)
assert_eq!(3, events.len());

if let DcState::VersionNegotiated { version, .. } = events[0].state {
assert_eq!(version, s2n_quic_core::dc::SUPPORTED_VERSIONS[0]);
} else {
panic!("VersionNegotiated should be the first dc state");
}

assert!(matches!(events[1].state, DcState::PathSecretsReady { .. }));
assert!(matches!(events[2].state, DcState::Complete { .. }));
}

#[derive(Clone)]
struct DcStateChangedEvent {
timestamp: Timestamp,
Expand Down

0 comments on commit 59564ff

Please sign in to comment.