Skip to content

Commit

Permalink
s2n-tls-tokio: use s2n_shutdown_send instead of s2n_shutdown (#4374)
Browse files Browse the repository at this point in the history
  • Loading branch information
lrstewart authored Jan 31, 2024
1 parent c128140 commit 7227160
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 92 deletions.
12 changes: 8 additions & 4 deletions bindings/rust/s2n-tls-tokio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,19 @@ where
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
ready!(self.as_mut().poll_blinding(ctx))?;

// s2n_shutdown must not be called again if it errors
// s2n_shutdown_send must not be called again if it errors
if self.shutdown_error.is_none() {
let result = ready!(self.as_mut().with_io(ctx, |mut context| {
context.conn.as_mut().poll_shutdown().map(|r| r.map(|_| ()))
context
.conn
.as_mut()
.poll_shutdown_send()
.map(|r| r.map(|_| ()))
}));
if let Err(error) = result {
self.shutdown_error = Some(error);
// s2n_shutdown reading might have triggered blinding again
ready!(self.as_mut().poll_blinding(ctx))?;
// s2n_shutdown_send only writes, so will never trigger blinding again.
// So we do not need to poll_blinding again after this error.
}
};

Expand Down
111 changes: 23 additions & 88 deletions bindings/rust/s2n-tls-tokio/tests/shutdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ use tokio::{

pub mod common;

// An arbitrary but very long timeout.
// No valid single IO operation should take anywhere near 10 minutes.
pub const LONG_TIMEOUT: time::Duration = time::Duration::from_secs(600);

async fn read_until_shutdown<S: AsyncRead + AsyncWrite + Unpin>(
stream: &mut TlsStream<S>,
) -> Result<(), std::io::Error> {
Expand Down Expand Up @@ -166,18 +162,6 @@ async fn shutdown_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
let (mut client, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// Attempt to shutdown the client. This will eventually fail because the
// server has not written the close_notify message yet, but it will at least
// write the close_notify message that the server needs.
//
// Because this test begins paused and relies on auto-advancing, this does
// not actually require waiting LONG_TIMEOUT. See the tokio `pause()` docs:
// https://docs.rs/tokio/latest/tokio/time/fn.pause.html
//
// TODO: replace this with a half-close once the bindings support half-close.
let timeout = time::timeout(LONG_TIMEOUT, client.shutdown()).await;
assert!(timeout.is_err());

// Setup a bad record for the next read
overrides.next_read(Some(Box::new(|_, _, buf| {
// Parsing the header is one of the blinded operations
Expand All @@ -202,53 +186,9 @@ async fn shutdown_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
// Server MUST eventually successfully shutdown
assert!(result.is_ok());

// Shutdown MUST have sent the close_notify message needed by the peer
// to also shutdown successfully.
client.shutdown().await?;

Ok(())
}

#[tokio::test(start_paused = true)]
async fn shutdown_with_blinding_bad_close_record() -> Result<(), Box<dyn std::error::Error>> {
let clock = common::TokioTime::default();
let mut server_config = common::server_config()?;
server_config.set_monotonic_clock(clock)?;

let client = TlsConnector::new(common::client_config()?.build()?);
let server = TlsAcceptor::new(server_config.build()?);

let (server_stream, client_stream) = common::get_streams().await?;
let server_stream = common::TestStream::new(server_stream);
let overrides = server_stream.overrides();
let (mut client, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// Setup a bad record for the next read
overrides.next_read(Some(Box::new(|_, _, buf| {
// Parsing the header is one of the blinded operations
// in s2n_shutdown, so provide a malformed header.
let zeroed_header = [23, 0, 0, 0, 0];
buf.put_slice(&zeroed_header);
Ok(()).into()
})));

let time_start = time::Instant::now();
let result = server.shutdown().await;
let time_elapsed = time_start.elapsed();

// Shutdown MUST NOT complete faster than minimal blinding time.
assert!(time_elapsed > common::MIN_BLINDING_SECS);

// Shutdown MUST eventually complete with the correct error after blinding.
let io_error = result.unwrap_err();
let error: error::Error = io_error.try_into()?;
assert!(error.kind() == error::ErrorType::ProtocolError);
assert!(error.name() == "S2N_ERR_BAD_MESSAGE");

// Shutdown MUST have sent the close_notify message needed by the peer
// to also shutdown successfully.
client.shutdown().await?;
// Shutdown MUST have sent the close_notify message needed for EOF.
let mut received = [0; 1];
assert!(client.read(&mut received).await? == 0);

Ok(())
}
Expand Down Expand Up @@ -295,7 +235,7 @@ async fn shutdown_with_poll_blinding() -> Result<(), Box<dyn std::error::Error>>
Ok(())
}

#[tokio::test(start_paused = true)]
#[tokio::test]
async fn shutdown_with_tcp_error() -> Result<(), Box<dyn std::error::Error>> {
let client = TlsConnector::new(common::client_config()?.build()?);
let server = TlsAcceptor::new(common::server_config()?.build()?);
Expand All @@ -304,20 +244,9 @@ async fn shutdown_with_tcp_error() -> Result<(), Box<dyn std::error::Error>> {
let server_stream = common::TestStream::new(server_stream);
let overrides = server_stream.overrides();

let (mut client, mut server) =
let (_, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// Attempt to shutdown the client. This will eventually fail because the
// server has not written the close_notify message yet, but it will at least
// write the close_notify message that the server needs.
//
// Because this test begins paused and relies on auto-advancing, this does
// not actually require waiting LONG_TIMEOUT. See the tokio `pause()` docs:
// https://docs.rs/tokio/latest/tokio/time/fn.pause.html
//
// TODO: replace this with a half-close once the bindings support half-close.
_ = time::timeout(time::Duration::from_secs(600), client.shutdown()).await;

// The underlying stream should return a unique error on shutdown
overrides.next_shutdown(Some(Box::new(|_, _| {
Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR)))
Expand All @@ -343,22 +272,22 @@ async fn shutdown_with_tls_error_and_tcp_error() -> Result<(), Box<dyn std::erro
let (_, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// Both s2n_shutdown and the underlying stream should error on shutdown
overrides.next_read(Some(Box::new(|_, _, _| {
// Both s2n_shutdown_send and the underlying stream should error on shutdown
overrides.next_write(Some(Box::new(|_, _, _| {
Ready(Err(io::Error::from(io::ErrorKind::Other)))
})));
overrides.next_shutdown(Some(Box::new(|_, _| {
Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR)))
})));

// Shutdown should complete with the correct error from s2n_shutdown
// Shutdown should complete with the correct error from s2n_shutdown_send
let result = server.shutdown().await;
let io_error = result.unwrap_err();
let error: error::Error = io_error.try_into()?;
// Any non-blocking read error is translated as "IOError"
assert!(error.kind() == error::ErrorType::IOError);

// Even if s2n_shutdown fails, we need to close the underlying stream.
// Even if s2n_shutdown_send fails, we need to close the underlying stream.
// Make sure we called our mock shutdown, consuming it.
assert!(overrides.is_consumed());

Expand All @@ -374,14 +303,11 @@ async fn shutdown_with_tls_error_and_tcp_delay() -> Result<(), Box<dyn std::erro
let server_stream = common::TestStream::new(server_stream);
let overrides = server_stream.overrides();

let (_, mut server) =
let (mut client, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// We want s2n_shutdown to fail on read in order to ensure that it is only
// called once on failure.
// If s2n_shutdown were called again, the second call would hang waiting
// for nonexistent input from the peer.
overrides.next_read(Some(Box::new(|_, _, _| {
// We want s2n_shutdown_send to produce an error on write
overrides.next_write(Some(Box::new(|_, _, _| {
Ready(Err(io::Error::from(io::ErrorKind::Other)))
})));

Expand All @@ -391,16 +317,25 @@ async fn shutdown_with_tls_error_and_tcp_delay() -> Result<(), Box<dyn std::erro
Pending
})));

// Shutdown should complete with the correct error from s2n_shutdown
// Shutdown should complete with the correct error from s2n_shutdown_send
let result = server.shutdown().await;
let io_error = result.unwrap_err();
let error: error::Error = io_error.try_into()?;
// Any non-blocking read error is translated as "IOError"
assert!(error.kind() == error::ErrorType::IOError);

// Even if s2n_shutdown fails, we need to close the underlying stream.
// Even if s2n_shutdown_send fails, we need to close the underlying stream.
// Make sure we at least called our mock shutdown, consuming it.
assert!(overrides.is_consumed());

// Since s2n_shutdown_send failed, we should NOT have sent a close_notify.
// Make sure the peer doesn't receive a close_notify.
// If this is not true, then we're incorrectly calling s2n_shutdown_send
// again after an error.
let mut received = [0; 1];
let io_error = client.read(&mut received).await.unwrap_err();
let error: error::Error = io_error.try_into()?;
assert!(error.kind() == error::ErrorType::ConnectionClosed);

Ok(())
}
16 changes: 16 additions & 0 deletions bindings/rust/s2n-tls/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,22 @@ impl Connection {
}
}

/// Attempts a graceful shutdown of the write side of a TLS connection.
///
/// Unlike Self::poll_shutdown, no reponse from the peer is necessary.
/// If using TLS1.3, the connection can continue to be used for reading afterwards.
pub fn poll_shutdown_send(&mut self) -> Poll<Result<&mut Self, Error>> {
if !self.remaining_blinding_delay()?.is_zero() {
return Poll::Pending;
}
let mut blocked = s2n_blocked_status::NOT_BLOCKED;
unsafe {
s2n_shutdown_send(self.connection.as_ptr(), &mut blocked)
.into_poll()
.map_ok(|_| self)
}
}

/// Returns the TLS alert code, if any
pub fn alert(&self) -> Option<u8> {
let alert =
Expand Down

0 comments on commit 7227160

Please sign in to comment.