From 4e03c47b5969cd979de736cecdc851a2a3f7b193 Mon Sep 17 00:00:00 2001 From: Jakob Schikowski Date: Mon, 9 Sep 2024 00:30:26 +0200 Subject: [PATCH] refactor!: Reimplement stream without split --- Cargo.toml | 2 +- README.md | 2 +- examples/client.rs | 2 +- examples/client_authenticate.rs | 2 +- examples/client_idle.rs | 2 +- examples/server.rs | 2 +- examples/server_authenticate.rs | 2 +- examples/server_idle.rs | 2 +- integration-test/Cargo.toml | 2 +- integration-test/src/client_tester.rs | 9 +- integration-test/src/server_tester.rs | 15 +- src/stream.rs | 274 +++++++------------------- src/tests.rs | 4 +- 13 files changed, 103 insertions(+), 217 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c436d9f..f2969e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ ext_metadata = ["imap-codec/ext_metadata"] bytes = { version = "1.7.1", optional = true } imap-codec = { version = "2.0.0-alpha.5", features = ["quirk_crlf_relaxed"] } thiserror = "1.0.63" -tokio = { version = "1.40.0", optional = true, features = ["io-util", "macros", "net"] } +tokio = { version = "1.40.0", optional = true, features = ["io-util", "net"] } tokio-rustls = { version = "0.26.0", optional = true, default-features = false } tracing = "0.1.40" diff --git a/README.md b/README.md index 54cf4ca..3da50ae 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ use tokio::net::TcpStream; #[tokio::main] async fn main() -> Result<(), Box> { - let mut stream = Stream::insecure(TcpStream::connect("127.0.0.1:1143").await?); + let mut stream = Stream::new(TcpStream::connect("127.0.0.1:1143").await?); let mut client = Client::new(Options::default()); loop { diff --git a/examples/client.rs b/examples/client.rs index 73ac6e3..3c46700 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -11,7 +11,7 @@ use tokio::net::TcpStream; #[tokio::main(flavor = "current_thread")] async fn main() { let stream = TcpStream::connect("127.0.0.1:12345").await.unwrap(); - let mut stream = Stream::insecure(stream); + let mut stream = Stream::new(stream); let mut client = Client::new(Options::default()); let greeting = loop { diff --git a/examples/client_authenticate.rs b/examples/client_authenticate.rs index 4b56ce7..c7146be 100644 --- a/examples/client_authenticate.rs +++ b/examples/client_authenticate.rs @@ -14,7 +14,7 @@ use tokio::net::TcpStream; #[tokio::main(flavor = "current_thread")] async fn main() { let stream = TcpStream::connect("127.0.0.1:12345").await.unwrap(); - let mut stream = Stream::insecure(stream); + let mut stream = Stream::new(stream); let mut client = Client::new(Options::default()); loop { diff --git a/examples/client_idle.rs b/examples/client_idle.rs index cb5d149..c652bdc 100644 --- a/examples/client_idle.rs +++ b/examples/client_idle.rs @@ -14,7 +14,7 @@ use tokio::{net::TcpStream, sync::mpsc::Receiver}; #[tokio::main(flavor = "current_thread")] async fn main() { let stream = TcpStream::connect("127.0.0.1:12345").await.unwrap(); - let mut stream = Stream::insecure(stream); + let mut stream = Stream::new(stream); let mut client = Client::new(Options::default()); loop { diff --git a/examples/server.rs b/examples/server.rs index f465202..3c032b4 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -11,7 +11,7 @@ use tokio::net::TcpListener; async fn main() { let listener = TcpListener::bind("127.0.0.1:12345").await.unwrap(); let (stream, _) = listener.accept().await.unwrap(); - let mut stream = Stream::insecure(stream); + let mut stream = Stream::new(stream); let mut server = Server::new( Options::default(), Greeting::ok(None, "server (example)").unwrap(), diff --git a/examples/server_authenticate.rs b/examples/server_authenticate.rs index 1e9d5ff..8254a31 100644 --- a/examples/server_authenticate.rs +++ b/examples/server_authenticate.rs @@ -10,7 +10,7 @@ use tokio::net::TcpListener; async fn main() { let listener = TcpListener::bind("127.0.0.1:12345").await.unwrap(); let (stream, _) = listener.accept().await.unwrap(); - let mut stream = Stream::insecure(stream); + let mut stream = Stream::new(stream); let mut server = Server::new( Options::default(), Greeting::ok(None, "server_idle (example)").unwrap(), diff --git a/examples/server_idle.rs b/examples/server_idle.rs index 2a21dcc..1e2dfea 100644 --- a/examples/server_idle.rs +++ b/examples/server_idle.rs @@ -16,7 +16,7 @@ use tokio::{net::TcpListener, sync::mpsc::Receiver}; async fn main() { let listener = TcpListener::bind("127.0.0.1:12345").await.unwrap(); let (stream, _) = listener.accept().await.unwrap(); - let mut stream = Stream::insecure(stream); + let mut stream = Stream::new(stream); let mut server = Server::new( Options::default(), Greeting::ok(None, "server_idle (example)").unwrap(), diff --git a/integration-test/Cargo.toml b/integration-test/Cargo.toml index d45c2f0..5ed2d43 100644 --- a/integration-test/Cargo.toml +++ b/integration-test/Cargo.toml @@ -8,7 +8,7 @@ publish = false [dependencies] bstr = { version = "1.10.0", default-features = false } bytes = "1.7.1" -imap-codec = { version = "2.0.0-alpha.4" } +imap-codec = { version = "2.0.0-alpha.5" } imap-next = { path = ".." } tokio = { version = "1.40.0", features = ["macros", "net", "rt", "time"] } tracing = "0.1.40" diff --git a/integration-test/src/client_tester.rs b/integration-test/src/client_tester.rs index 05d65d9..af751df 100644 --- a/integration-test/src/client_tester.rs +++ b/integration-test/src/client_tester.rs @@ -25,7 +25,7 @@ impl ClientTester { ) -> Self { let stream = TcpStream::connect(server_address).await.unwrap(); trace!(?server_address, "Client is connected"); - let stream = Stream::insecure(stream); + let stream = Stream::new(stream); let client = Client::new(client_options); Self { codecs, @@ -344,13 +344,16 @@ impl ClientTester { #[allow(clippy::large_enum_variant)] enum ConnectionState { /// Connection to server established. - Connected { stream: Stream, client: Client }, + Connected { + stream: Stream, + client: Client, + }, /// Connection dropped. Disconnected, } impl ConnectionState { - fn connected(&mut self) -> (&mut Stream, &mut Client) { + fn connected(&mut self) -> (&mut Stream, &mut Client) { match self { ConnectionState::Connected { stream, client } => (stream, client), ConnectionState::Disconnected => panic!("Client is already disconnected"), diff --git a/integration-test/src/server_tester.rs b/integration-test/src/server_tester.rs index 153e0fb..49aa803 100644 --- a/integration-test/src/server_tester.rs +++ b/integration-test/src/server_tester.rs @@ -4,7 +4,7 @@ use imap_next::{ server::{self, ResponseHandle, Server}, stream::{self, Stream}, }; -use tokio::net::TcpListener; +use tokio::net::{TcpListener, TcpStream}; use tracing::trace; use crate::codecs::Codecs; @@ -24,7 +24,7 @@ impl ServerTester { ) -> Self { let (stream, client_address) = server_listener.accept().await.unwrap(); trace!(?client_address, "Server accepts connection"); - let stream = Stream::insecure(stream); + let stream = Stream::new(stream); Self { codecs, server_options, @@ -308,15 +308,20 @@ impl ServerTester { #[allow(clippy::large_enum_variant)] enum ConnectionState { // Connection to client established. - Connected { stream: Stream }, + Connected { + stream: Stream, + }, // Server greeted client. - Greeted { stream: Stream, server: Server }, + Greeted { + stream: Stream, + server: Server, + }, // Connection dropped. Disconnected, } impl ConnectionState { - fn greeted(&mut self) -> (&mut Stream, &mut Server) { + fn greeted(&mut self) -> (&mut Stream, &mut Server) { match self { ConnectionState::Connected { .. } => panic!("Server has not greeted yet"), ConnectionState::Greeted { stream, server } => (stream, server), diff --git a/src/stream.rs b/src/stream.rs index 057cf1c..7acdd65 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,109 +1,73 @@ use std::{ convert::Infallible, - io::{ErrorKind, Read, Write}, + future::{poll_fn, Future}, + pin::pin, + task::{Context, Poll}, }; -use bytes::{Buf, BufMut, BytesMut}; +use bytes::{Buf, BytesMut}; #[cfg(debug_assertions)] use imap_codec::imap_types::utils::escape_byte_string; use thiserror::Error; -use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, - net::TcpStream, - select, -}; -use tokio_rustls::{rustls, TlsStream}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio_rustls::rustls; #[cfg(debug_assertions)] use tracing::trace; use crate::{Interrupt, Io, State}; -pub struct Stream { - stream: TcpStream, - tls: Option, +pub struct Stream { + stream: S, read_buffer: BytesMut, write_buffer: BytesMut, } -impl Stream { - pub fn insecure(stream: TcpStream) -> Self { +impl Stream { + pub fn new(stream: S) -> Self { Self { stream, - tls: None, read_buffer: BytesMut::default(), write_buffer: BytesMut::default(), } } +} - pub fn tls(stream: TlsStream) -> Self { - // We want to use `TcpStream::split` for handling reading and writing separately, - // but `TlsStream` does not expose this functionality. Therefore, we destruct `TlsStream` - // into `TcpStream` and `rustls::Connection` and handling them ourselves. - // - // Some notes: - // - // - There is also `tokio::io::split` which works for all kind of streams. But this - // involves too much scary magic because its use-case is reading and writing from - // different threads. We prefer to use the more low-level `TcpStream::split`. - // - // - We could get rid of `TlsStream` and construct `rustls::Connection` directly. - // But `TlsStream` is still useful because it gives us the guarantee that the handshake - // was already handled properly. - // - // - In the long run it would be nice if `TlsStream::split` would exist and we would use - // it because `TlsStream` is better at handling the edge cases of `rustls`. - let (stream, tls) = match stream { - TlsStream::Client(stream) => { - let (stream, tls) = stream.into_inner(); - (stream, rustls::Connection::Client(tls)) - } - TlsStream::Server(stream) => { - let (stream, tls) = stream.into_inner(); - (stream, rustls::Connection::Server(tls)) - } - }; +impl Stream { + #[cfg(feature = "expose_stream")] + /// Return the underlying stream for debug purposes (or experiments). + /// + /// Note: Writing to or reading from the stream may introduce + /// conflicts with `imap-next`. + pub fn stream_mut(&mut self) -> &mut S { + &mut self.stream + } - Self { - stream, - tls: Some(tls), - read_buffer: BytesMut::default(), - write_buffer: BytesMut::default(), - } + /// Take the underlying stream out of a [`Stream`]. + /// + /// Useful when a TCP stream needs to be upgraded to a TLS one. + #[cfg(feature = "expose_stream")] + pub fn into_stream(self) -> S { + self.stream } +} +impl Stream { pub async fn flush(&mut self) -> Result<(), Error> { - // Flush TLS - if let Some(tls) = &mut self.tls { - tls.writer().flush()?; - encrypt(tls, &mut self.write_buffer, Vec::new())?; - } - // Flush TCP - write(&mut self.stream, &mut self.write_buffer).await?; + poll_fn(|cx| poll_write_stream(&mut self.stream, cx, &mut self.write_buffer)).await?; self.stream.flush().await?; Ok(()) } +} +impl Stream { pub async fn next(&mut self, mut state: F) -> Result> { let event = loop { - match &mut self.tls { - None => { - // Provide input bytes to the client/server - if !self.read_buffer.is_empty() { - state.enqueue_input(&self.read_buffer); - self.read_buffer.clear(); - } - } - Some(tls) => { - // Decrypt input bytes - let plain_bytes = decrypt(tls, &mut self.read_buffer)?; - - // Provide input bytes to the client/server - if !plain_bytes.is_empty() { - state.enqueue_input(&plain_bytes); - } - } + // Provide input bytes to the client/server + if !self.read_buffer.is_empty() { + state.enqueue_input(&self.read_buffer); + self.read_buffer.clear(); } // Progress the client/server @@ -121,62 +85,32 @@ impl Stream { Interrupt::Error(err) => return Err(Error::State(err)), }; - match &mut self.tls { - None => { - // Handle the output bytes from the client/server - if let Io::Output(bytes) = io { - self.write_buffer.extend(bytes); - } - } - Some(tls) => { - // Handle the output bytes from the client/server - let plain_bytes = if let Io::Output(bytes) = io { - bytes - } else { - Vec::new() - }; - - // Encrypt output bytes - encrypt(tls, &mut self.write_buffer, plain_bytes)?; - } + // Handle the output bytes from the client/server + if let Io::Output(bytes) = io { + self.write_buffer.extend(bytes); } // Progress the stream if self.write_buffer.is_empty() { - read(&mut self.stream, &mut self.read_buffer).await?; + poll_fn(|cx| poll_read_stream(&mut self.stream, cx, &mut self.read_buffer)).await?; } else { // We read and write the stream simultaneously because otherwise // a deadlock between client and server might occur if both sides // would only read or only write. - let (read_stream, write_stream) = self.stream.split(); - select! { - result = read(read_stream, &mut self.read_buffer) => result, - result = write(write_stream, &mut self.write_buffer) => result, - }?; + poll_fn(|cx| { + match poll_write_stream(&mut self.stream, cx, &mut self.write_buffer) { + Poll::Ready(result) => Poll::Ready(result), + Poll::Pending => { + poll_read_stream(&mut self.stream, cx, &mut self.read_buffer) + } + } + }) + .await?; }; }; Ok(event) } - - #[cfg(feature = "expose_stream")] - /// Return the underlying stream for debug purposes (or experiments). - /// - /// Note: Writing to or reading from the stream may introduce - /// conflicts with `imap-next`. - pub fn stream_mut(&mut self) -> &mut TcpStream { - &mut self.stream - } -} - -/// Take the [`TcpStream`] out of a [`Stream`]. -/// -/// Useful when a TCP stream needs to be upgraded to a TLS one. -#[cfg(feature = "expose_stream")] -impl From for TcpStream { - fn from(stream: Stream) -> Self { - stream.stream - } } /// Error during reading into or writing from a stream. @@ -199,13 +133,21 @@ pub enum Error { State(E), } -async fn read( - mut stream: S, +fn poll_read_stream( + stream: &mut S, + cx: &mut Context<'_>, read_buffer: &mut BytesMut, -) -> Result<(), ReadWriteError> { +) -> Poll> { #[cfg(debug_assertions)] let old_len = read_buffer.len(); - let byte_count = stream.read_buf(read_buffer).await?; + + // Constructing this future is cheap + let read_buf_future = pin!(stream.read_buf(read_buffer)); + let Poll::Ready(read_buf_result) = read_buf_future.poll(cx) else { + return Poll::Pending; + }; + let byte_count = read_buf_result?; + #[cfg(debug_assertions)] trace!( data = escape_byte_string(&read_buffer[old_len..]), @@ -216,34 +158,42 @@ async fn read( // The result is 0 if the stream reached "end of file" or the read buffer was // already full before calling `read_buf`. Because we use an unlimited buffer we // know that the first case occurred. - return Err(ReadWriteError::Closed); + return Poll::Ready(Err(ReadWriteError::Closed)); } - Ok(()) + Poll::Ready(Ok(())) } -async fn write( - mut stream: S, +fn poll_write_stream( + stream: &mut S, + cx: &mut Context<'_>, write_buffer: &mut BytesMut, -) -> Result<(), ReadWriteError> { +) -> Poll> { while !write_buffer.is_empty() { - let byte_count = stream.write(write_buffer).await?; + // Constructing this future is cheap + let write_future = pin!(stream.write(write_buffer)); + let Poll::Ready(write_result) = write_future.poll(cx) else { + return Poll::Pending; + }; + let byte_count = write_result?; + #[cfg(debug_assertions)] trace!( data = escape_byte_string(&write_buffer[..byte_count]), "io/write/raw" ); + write_buffer.advance(byte_count); if byte_count == 0 { // The result is 0 if the stream doesn't accept bytes anymore or the write buffer // was already empty before calling `write_buf`. Because we checked the buffer // we know that the first case occurred. - return Err(ReadWriteError::Closed); + return Poll::Ready(Err(ReadWriteError::Closed)); } } - Ok(()) + Poll::Ready(Ok(())) } #[derive(Debug, Error)] @@ -262,75 +212,3 @@ impl From for Error { } } } - -fn decrypt( - tls: &mut rustls::Connection, - read_buffer: &mut BytesMut, -) -> Result, DecryptEncryptError> { - let mut plain_bytes = Vec::new(); - - while tls.wants_read() && !read_buffer.is_empty() { - let mut encrypted_bytes = read_buffer.reader(); - tls.read_tls(&mut encrypted_bytes)?; - tls.process_new_packets()?; - } - - loop { - let mut plain_bytes_chunk = [0; 128]; - // We need to handle different cases according to: - // https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read - match tls.reader().read(&mut plain_bytes_chunk) { - // There are no more bytes to read - Err(err) if err.kind() == ErrorKind::WouldBlock => break, - // The TLS session was closed uncleanly - Err(err) if err.kind() == ErrorKind::UnexpectedEof => { - return Err(DecryptEncryptError::Closed) - } - // We got an unexpected error - Err(err) => return Err(DecryptEncryptError::Io(err)), - // The TLS session was closed cleanly - Ok(0) => return Err(DecryptEncryptError::Closed), - // We read some plaintext bytes - Ok(n) => plain_bytes.extend(&plain_bytes_chunk[0..n]), - }; - } - - Ok(plain_bytes) -} - -fn encrypt( - tls: &mut rustls::Connection, - write_buffer: &mut BytesMut, - plain_bytes: Vec, -) -> Result<(), DecryptEncryptError> { - if !plain_bytes.is_empty() { - tls.writer().write_all(&plain_bytes)?; - } - - while tls.wants_write() { - let mut encrypted_bytes = write_buffer.writer(); - tls.write_tls(&mut encrypted_bytes)?; - } - - Ok(()) -} - -#[derive(Debug, Error)] -enum DecryptEncryptError { - #[error("Session was closed")] - Closed, - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Tls(#[from] rustls::Error), -} - -impl From for Error { - fn from(value: DecryptEncryptError) -> Self { - match value { - DecryptEncryptError::Closed => Error::Closed, - DecryptEncryptError::Io(err) => Error::Io(err), - DecryptEncryptError::Tls(err) => Error::Tls(err), - } - } -} diff --git a/src/tests.rs b/src/tests.rs index 7132e05..9366462 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -25,7 +25,7 @@ async fn self_test() { async move { let (stream, _) = listener.accept().await.unwrap(); - let mut stream = Stream::insecure(stream); + let mut stream = Stream::new(stream); let mut server = Server::new(server::Options::default(), greeting.clone()); loop { @@ -50,7 +50,7 @@ async fn self_test() { let _ = tokio::task::spawn(server); let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap(); - let mut stream = Stream::insecure(stream); + let mut stream = Stream::new(stream); let mut client = Client::new(client::Options::default()); client.enqueue_command(Command::new(Tag::unvalidated("A1"), CommandBody::Capability).unwrap());