diff --git a/Cargo.toml b/Cargo.toml index 9749d786..a91fa37f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,9 @@ name = "imap-flow" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] +bounded-static = "0.5.0" +bytes = "1.5.0" +imap-codec = { version = "1.0.0", features = ["quirk_crlf_relaxed", "bounded-static"] } +thiserror = "1.0.49" +tokio = { version = "1.32.0", features = ["io-util"] } diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 00000000..f629f121 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,188 @@ +use bounded_static::ToBoundedStatic; +use bytes::BytesMut; +use imap_codec::{ + decode::{GreetingDecodeError, ResponseDecodeError}, + imap_types::{ + command::Command, + core::Tag, + response::{Data, Greeting, Response, Status}, + }, + CommandCodec, GreetingCodec, ResponseCodec, +}; +use thiserror::Error; + +use crate::{ + receive::{ReceiveEvent, ReceiveState}, + send::SendCommandState, + stream::AnyStream, +}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct ClientFlowOptions { + pub crlf_relaxed: bool, +} + +pub struct ClientFlow { + stream: AnyStream, + + next_command_handle: ClientFlowCommandHandle, + send_command_state: SendCommandState<(Tag<'static>, ClientFlowCommandHandle)>, + receive_response_state: ReceiveState, +} + +impl ClientFlow { + pub async fn receive_greeting( + mut stream: AnyStream, + options: ClientFlowOptions, + ) -> Result<(Self, Greeting<'static>), ClientFlowError> { + // Receive greeting + let read_buffer = BytesMut::new(); + let mut receive_greeting_state = + ReceiveState::new(GreetingCodec::default(), options.crlf_relaxed, read_buffer); + let greeting = match receive_greeting_state.progress(&mut stream).await? { + ReceiveEvent::DecodingSuccess(greeting) => { + receive_greeting_state.finish_message(); + greeting + } + ReceiveEvent::DecodingFailure( + GreetingDecodeError::Failed | GreetingDecodeError::Incomplete, + ) => { + let discarded_bytes = receive_greeting_state.discard_message(); + return Err(ClientFlowError::MalformedMessage { discarded_bytes }); + } + ReceiveEvent::ExpectedCrlfGotLf => { + let discarded_bytes = receive_greeting_state.discard_message(); + return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes }); + } + }; + + // Successfully received greeting, create instance. + let write_buffer = BytesMut::new(); + let send_command_state = SendCommandState::new(CommandCodec::default(), write_buffer); + let read_buffer = receive_greeting_state.finish(); + let receive_response_state = + ReceiveState::new(ResponseCodec::new(), options.crlf_relaxed, read_buffer); + let client_flow = Self { + stream, + next_command_handle: ClientFlowCommandHandle(0), + send_command_state, + receive_response_state, + }; + + Ok((client_flow, greeting)) + } + + pub fn enqueue_command(&mut self, command: Command<'_>) -> ClientFlowCommandHandle { + let handle = self.next_command_handle; + self.next_command_handle = ClientFlowCommandHandle(handle.0 + 1); + let tag = command.tag.to_static(); + self.send_command_state.enqueue((tag, handle), command); + handle + } + + pub async fn progress(&mut self) -> Result { + loop { + if let Some(event) = self.progress_command().await? { + return Ok(event); + } + + if let Some(event) = self.progress_response().await? { + return Ok(event); + } + } + } + + async fn progress_command(&mut self) -> Result, ClientFlowError> { + match self.send_command_state.progress(&mut self.stream).await? { + Some((tag, handle)) => Ok(Some(ClientFlowEvent::CommandSent { tag, handle })), + None => Ok(None), + } + } + + async fn progress_response(&mut self) -> Result, ClientFlowError> { + let event = loop { + let response = match self + .receive_response_state + .progress(&mut self.stream) + .await? + { + ReceiveEvent::DecodingSuccess(response) => { + self.receive_response_state.finish_message(); + response + } + ReceiveEvent::DecodingFailure(ResponseDecodeError::LiteralFound { length }) => { + // The client must accept the literal in any case. + self.receive_response_state.start_literal(length); + continue; + } + ReceiveEvent::DecodingFailure( + ResponseDecodeError::Failed | ResponseDecodeError::Incomplete, + ) => { + let discarded_bytes = self.receive_response_state.discard_message(); + return Err(ClientFlowError::MalformedMessage { discarded_bytes }); + } + ReceiveEvent::ExpectedCrlfGotLf => { + let discarded_bytes = self.receive_response_state.discard_message(); + return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes }); + } + }; + + match response { + Response::Status(status) => { + self.maybe_abort_command(&status); + break Some(ClientFlowEvent::StatusReceived { status }); + } + Response::Data(data) => break Some(ClientFlowEvent::DataReceived { data }), + Response::CommandContinuationRequest(_) => { + self.send_command_state.continue_command(); + break None; + } + } + }; + + Ok(event) + } + + fn maybe_abort_command(&mut self, status: &Status) { + let Some((command_tag, _)) = self.send_command_state.command_in_progress() else { + return; + }; + + match status { + Status::Bad { + tag: Some(status_tag), + .. + } if status_tag == command_tag => { + self.send_command_state.abort_command(); + } + _ => (), + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub struct ClientFlowCommandHandle(u64); + +#[derive(Debug)] +pub enum ClientFlowEvent { + CommandSent { + tag: Tag<'static>, + handle: ClientFlowCommandHandle, + }, + DataReceived { + data: Data<'static>, + }, + StatusReceived { + status: Status<'static>, + }, +} + +#[derive(Debug, Error)] +pub enum ClientFlowError { + #[error(transparent)] + Io(#[from] tokio::io::Error), + #[error("Expected `\\r\\n`, got `\\n`")] + ExpectedCrlfGotLf { discarded_bytes: Box<[u8]> }, + #[error("Received malformed message")] + MalformedMessage { discarded_bytes: Box<[u8]> }, +} diff --git a/src/lib.rs b/src/lib.rs index 7d12d9af..92862d07 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,5 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } -} +pub mod client; +mod receive; +mod send; +pub mod server; +pub mod stream; diff --git a/src/receive.rs b/src/receive.rs new file mode 100644 index 00000000..b1a05e0d --- /dev/null +++ b/src/receive.rs @@ -0,0 +1,171 @@ +use bounded_static::IntoBoundedStatic; +use bytes::{Buf, BytesMut}; +use imap_codec::decode::Decoder; +use tokio::io::AsyncReadExt; + +use crate::stream::AnyStream; + +pub struct ReceiveState { + codec: C, + crlf_relaxed: bool, + next_fragment: NextFragment, + // How many bytes in the parse buffer do we already have checked? + // This is important if we need multiple attempts to read from the underlying + // stream before the message is completely received. + seen_bytes: usize, + // Used for reading the current message from the stream. + // Its length should always be equal to or greater than `seen_bytes`. + read_buffer: BytesMut, +} + +impl ReceiveState { + pub fn new(codec: C, crlf_relaxed: bool, read_buffer: BytesMut) -> Self { + Self { + codec, + crlf_relaxed, + next_fragment: NextFragment::default(), + seen_bytes: 0, + read_buffer, + } + } + + pub fn start_literal(&mut self, length: u32) { + self.next_fragment = NextFragment::Literal { length }; + self.read_buffer.reserve(length as usize); + } + + pub fn finish_message(&mut self) { + self.read_buffer.advance(self.seen_bytes); + self.seen_bytes = 0; + self.next_fragment = NextFragment::default(); + } + + pub fn discard_message(&mut self) -> Box<[u8]> { + let discarded_bytes = self.read_buffer[..self.seen_bytes].into(); + self.finish_message(); + discarded_bytes + } + + pub fn finish(self) -> BytesMut { + self.read_buffer + } + + pub async fn progress( + &mut self, + stream: &mut AnyStream, + ) -> Result, tokio::io::Error> + where + for<'a> C::Message<'a>: IntoBoundedStatic>, + for<'a> C::Error<'a>: IntoBoundedStatic>, + { + loop { + match self.next_fragment { + NextFragment::Line => { + if let Some(event) = self.progress_line(stream).await? { + return Ok(event); + } + } + NextFragment::Literal { length } => { + self.progress_literal(stream, length).await?; + } + }; + } + } + + async fn progress_line( + &mut self, + stream: &mut AnyStream, + ) -> Result>, tokio::io::Error> + where + for<'a> C::Message<'a>: IntoBoundedStatic>, + for<'a> C::Error<'a>: IntoBoundedStatic>, + { + // TODO: If the line is really long and we need multiple attempts to receive it, then this is O(n^2). + // This could be fixed by setting seen bytes in the None case + let crlf_result = match find_crlf(&self.read_buffer[self.seen_bytes..], self.crlf_relaxed) { + Some(crlf_result) => crlf_result, + None => { + // No full line received yet, more data needed. + stream.0.read_buf(&mut self.read_buffer).await?; + return Ok(None); + } + }; + + // Mark the all bytes of the current line as seen. + self.seen_bytes += crlf_result.lf_position + 1; + + if crlf_result.expected_crlf_got_lf { + return Ok(Some(ReceiveEvent::ExpectedCrlfGotLf)); + } + + // Try to parse the whole message from the start (including the new line). + // TODO: If the message is really long and we need multiple attempts to receive it, then this is O(n^2) + // IMO this can be only fixed by using a generator-like decoder + match self.codec.decode(&self.read_buffer[..self.seen_bytes]) { + Ok((remaining, message)) => { + assert!(remaining.is_empty()); + Ok(Some(ReceiveEvent::DecodingSuccess(message.into_static()))) + } + Err(error) => Ok(Some(ReceiveEvent::DecodingFailure(error.into_static()))), + } + } + + async fn progress_literal( + &mut self, + stream: &mut AnyStream, + literal_length: u32, + ) -> Result<(), tokio::io::Error> { + let unseen_bytes = self.read_buffer.len() - self.seen_bytes; + + if unseen_bytes < literal_length as usize { + // We did not receive enough bytes for the literal yet. + stream.0.read_buf(&mut self.read_buffer).await?; + } else { + // We received enough bytes for the literal. + // Now we can continue reading the next line. + self.next_fragment = NextFragment::Line; + self.seen_bytes += literal_length as usize; + } + + Ok(()) + } +} + +pub enum ReceiveEvent { + DecodingSuccess(C::Message<'static>), + DecodingFailure(C::Error<'static>), + ExpectedCrlfGotLf, +} + +// The next fragment that will be read... +#[derive(Clone, Copy, Default)] +enum NextFragment { + // ... is a line. + // + // Note: A message always starts (and ends) with a line. + #[default] + Line, + // ... is a literal with the given length. + Literal { + length: u32, + }, +} + +// A line ending for the current line was found. +struct FindCrlfResult { + // The position of the `\n` symbol + lf_position: usize, + // Is the line ending `\n` even though we expected `\r\n`? + expected_crlf_got_lf: bool, +} + +// Finds the line ending for the current line. +// Depending on `crlf_relaxed` the accepted line ending is `\n` (true) or `\r\n` (false). +fn find_crlf(buf: &[u8], crlf_relaxed: bool) -> Option { + let lf_position = buf.iter().position(|item| *item == b'\n')?; + let expected_crlf_got_lf = !crlf_relaxed && buf[lf_position.saturating_sub(1)] != b'\r'; + Some(FindCrlfResult { + lf_position, + expected_crlf_got_lf, + }) +} diff --git a/src/send.rs b/src/send.rs new file mode 100644 index 00000000..31a381b4 --- /dev/null +++ b/src/send.rs @@ -0,0 +1,236 @@ +use std::collections::VecDeque; + +use bytes::BytesMut; +use imap_codec::{ + encode::{Encoder, Fragment}, + imap_types::command::Command, + CommandCodec, +}; +use tokio::io::AsyncWriteExt; + +use crate::stream::AnyStream; + +pub struct SendCommandState { + codec: CommandCodec, + // The commands that should be send. + send_queue: VecDeque>, + // State of the command that is currently being sent. + send_progress: Option>, + // Used for writing the current command to the stream. + // Should be empty if `send_progress` is `None`. + write_buffer: BytesMut, +} + +impl SendCommandState { + pub fn new(codec: CommandCodec, write_buffer: BytesMut) -> Self { + Self { + codec, + send_queue: VecDeque::new(), + send_progress: None, + write_buffer, + } + } + + pub fn enqueue(&mut self, key: K, command: Command<'_>) { + let fragments = self.codec.encode(&command).collect(); + let entry = SendCommandQueueEntry { key, fragments }; + self.send_queue.push_back(entry); + } + + pub fn command_in_progress(&self) -> Option<&K> { + self.send_progress.as_ref().map(|x| &x.key) + } + + pub fn abort_command(&mut self) { + self.send_progress = None; + self.write_buffer.clear(); + } + + pub fn continue_command(&mut self) { + // TODO: Should we handle unexpected continues? + let Some(write_progress) = self.send_progress.as_mut() else { + return; + }; + let Some(literal_progress) = write_progress.next_literal.as_mut() else { + return; + }; + if literal_progress.received_continue { + return; + } + + literal_progress.received_continue = true; + } + + pub async fn progress( + &mut self, + stream: &mut AnyStream, + ) -> Result, tokio::io::Error> { + let progress = match self.send_progress.take() { + Some(progress) => { + // We are currently sending a command to the server. This sending process was + // previously aborted for one of two reasons: Either we needed to wait for a + // `Continue` from the server or the `Future` was dropped while sending. + progress + } + None => { + let Some(entry) = self.send_queue.pop_front() else { + // There is currently no command that need to be sent + return Ok(None); + }; + + // Start sending the next command + SendCommandProgress { + key: entry.key, + next_literal: None, + next_fragments: entry.fragments, + } + } + }; + let progress = self.send_progress.insert(progress); + + // Handle the outstanding literal first if there is one + if let Some(literal_progress) = progress.next_literal.take() { + if literal_progress.received_continue { + // We received a `Continue` from the server, we can send the literal now + self.write_buffer.extend(literal_progress.data); + } else { + // Delay this literal because we still wait for the `Continue` from the server + progress.next_literal = Some(literal_progress); + + // Make sure that the line before the literal is sent completely to the server + stream.0.write_all_buf(&mut self.write_buffer).await?; + + return Ok(None); + } + } + + // Handle the outstanding lines or literals + let need_continue = loop { + if let Some(fragment) = progress.next_fragments.pop_front() { + match fragment { + Fragment::Line { data } => { + self.write_buffer.extend(data); + } + Fragment::Literal { data, mode: _mode } => { + // TODO: Handle `LITERAL{+,-}`. + // Delay this literal because we need to wait for a `Continue` from + // the server + progress.next_literal = Some(SendCommandLiteralProgress { + data, + received_continue: false, + }); + break true; + } + } + } else { + break false; + } + }; + + // Send the bytes of the command to the server + stream.0.write_all_buf(&mut self.write_buffer).await?; + + if need_continue { + Ok(None) + } else { + // Command was sent completely + Ok(self.send_progress.take().map(|progress| progress.key)) + } + } +} + +struct SendCommandQueueEntry { + key: K, + fragments: VecDeque, +} + +struct SendCommandProgress { + key: K, + // If defined this literal need to be sent before `next_fragments`. + next_literal: Option, + // The fragments that need to be sent. + next_fragments: VecDeque, +} + +struct SendCommandLiteralProgress { + // The bytes of the literal. + data: Vec, + // Was the literal already acknowledged by a `Continue` from the server? + received_continue: bool, +} + +pub struct SendResponseState { + codec: C, + // The responses that should be sent. + send_queue: VecDeque>, + // Key of the response that is currently being sent. + send_in_progress_key: Option, + // Used for writing the current response to the stream. + // Should be empty if `send_in_progress_key` is `None`. + write_buffer: BytesMut, +} + +impl SendResponseState { + pub fn new(codec: C, write_buffer: BytesMut) -> Self { + Self { + codec, + send_queue: VecDeque::new(), + send_in_progress_key: None, + write_buffer, + } + } + + pub fn enqueue(&mut self, key: K, response: C::Message<'_>) { + let fragments = self.codec.encode(&response).collect(); + let entry = SendResponseQueueEntry { key, fragments }; + self.send_queue.push_back(entry); + } + + pub fn finish(mut self) -> BytesMut { + self.write_buffer.clear(); + self.write_buffer + } + + pub async fn progress( + &mut self, + stream: &mut AnyStream, + ) -> Result, tokio::io::Error> { + let send_in_progress_key = match self.send_in_progress_key.take() { + Some(key) => { + // We are currently sending a response. This sending process was + // previously aborted because the `Future` was dropped while sending. + key + } + None => { + let Some(entry) = self.send_queue.pop_front() else { + // There is currently no response that need to be sent + return Ok(None); + }; + + // Push the response to the write buffer + for fragment in entry.fragments { + let data = match fragment { + Fragment::Line { data } => data, + // TODO: Handle `LITERAL{+,-}`. + Fragment::Literal { data, mode: _mode } => data, + }; + self.write_buffer.extend(data); + } + + entry.key + } + }; + self.send_in_progress_key = Some(send_in_progress_key); + + // Send all bytes of current response + stream.0.write_all_buf(&mut self.write_buffer).await?; + + // response was sent completely + Ok(self.send_in_progress_key.take()) + } +} + +struct SendResponseQueueEntry { + key: K, + fragments: Vec, +} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 00000000..06ff2cab --- /dev/null +++ b/src/server.rs @@ -0,0 +1,172 @@ +use bytes::BytesMut; +use imap_codec::{ + decode::CommandDecodeError, + imap_types::{ + command::Command, + response::{CommandContinuationRequest, Data, Greeting, Response, Status}, + }, + CommandCodec, GreetingCodec, ResponseCodec, +}; +use thiserror::Error; + +use crate::{ + receive::{ReceiveEvent, ReceiveState}, + send::SendResponseState, + stream::AnyStream, +}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct ServerFlowOptions { + pub crlf_relaxed: bool, + pub max_literal_size: u32, +} + +pub struct ServerFlow { + stream: AnyStream, + max_literal_size: u32, + + next_response_handle: ServerFlowResponseHandle, + send_response_state: SendResponseState>, + receive_command_state: ReceiveState, +} + +impl ServerFlow { + pub async fn send_greeting( + mut stream: AnyStream, + options: ServerFlowOptions, + greeting: Greeting<'_>, + ) -> Result { + // Send greeting + let write_buffer = BytesMut::new(); + let mut send_greeting_state = + SendResponseState::new(GreetingCodec::default(), write_buffer); + send_greeting_state.enqueue((), greeting); + while let Some(()) = send_greeting_state.progress(&mut stream).await? {} + + // Successfully sent greeting, construct instance + let write_buffer = send_greeting_state.finish(); + let send_response_state = SendResponseState::new(ResponseCodec::default(), write_buffer); + let read_buffer = BytesMut::new(); + let receive_command_state = + ReceiveState::new(CommandCodec::default(), options.crlf_relaxed, read_buffer); + let server_flow = Self { + stream, + max_literal_size: options.max_literal_size, + next_response_handle: ServerFlowResponseHandle(0), + send_response_state, + receive_command_state, + }; + + Ok(server_flow) + } + + pub fn enqueue_data(&mut self, data: Data<'_>) -> ServerFlowResponseHandle { + let handle = self.next_response_handle(); + self.send_response_state + .enqueue(Some(handle), Response::Data(data)); + handle + } + + pub fn enqueue_status(&mut self, status: Status<'_>) -> ServerFlowResponseHandle { + let handle = self.next_response_handle(); + self.send_response_state + .enqueue(Some(handle), Response::Status(status)); + handle + } + + fn next_response_handle(&mut self) -> ServerFlowResponseHandle { + let handle = self.next_response_handle; + self.next_response_handle = ServerFlowResponseHandle(handle.0 + 1); + handle + } + + pub async fn progress(&mut self) -> Result { + loop { + if let Some(event) = self.progress_response().await? { + return Ok(event); + } + + if let Some(event) = self.progress_command().await? { + return Ok(event); + } + } + } + + async fn progress_response(&mut self) -> Result, ServerFlowError> { + match self.send_response_state.progress(&mut self.stream).await? { + Some(Some(handle)) => Ok(Some(ServerFlowEvent::ResponseSent { handle })), + _ => Ok(None), + } + } + + async fn progress_command(&mut self) -> Result, ServerFlowError> { + match self + .receive_command_state + .progress(&mut self.stream) + .await? + { + ReceiveEvent::DecodingSuccess(command) => { + self.receive_command_state.finish_message(); + Ok(Some(ServerFlowEvent::CommandReceived { command })) + } + ReceiveEvent::DecodingFailure(CommandDecodeError::LiteralFound { + tag, + length, + mode: _mode, + }) => { + if length > self.max_literal_size { + let discarded_bytes = self.receive_command_state.discard_message(); + + // Inform the client that the literal was rejected. + // This should never fail because the text is not Base64. + let status = Status::no(Some(tag), None, "Computer says no").unwrap(); + self.send_response_state + .enqueue(None, Response::Status(status)); + + Err(ServerFlowError::LiteralTooLong { discarded_bytes }) + } else { + self.receive_command_state.start_literal(length); + + // Inform the client that the literal was accepted. + // This should never fail because the text is not Base64. + let cont = CommandContinuationRequest::basic(None, "Please, continue").unwrap(); + self.send_response_state + .enqueue(None, Response::CommandContinuationRequest(cont)); + + Ok(None) + } + } + ReceiveEvent::DecodingFailure( + CommandDecodeError::Failed | CommandDecodeError::Incomplete, + ) => { + let discarded_bytes = self.receive_command_state.discard_message(); + Err(ServerFlowError::MalformedMessage { discarded_bytes }) + } + ReceiveEvent::ExpectedCrlfGotLf => { + let discarded_bytes = self.receive_command_state.discard_message(); + Err(ServerFlowError::ExpectedCrlfGotLf { discarded_bytes }) + } + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub struct ServerFlowResponseHandle(u64); + +#[derive(Debug)] +pub enum ServerFlowEvent { + ResponseSent { handle: ServerFlowResponseHandle }, + CommandReceived { command: Command<'static> }, +} + +#[derive(Debug, Error)] +pub enum ServerFlowError { + #[error(transparent)] + Io(#[from] tokio::io::Error), + #[error("Expected `\\r\\n`, got `\\n`")] + ExpectedCrlfGotLf { discarded_bytes: Box<[u8]> }, + #[error("Received malformed message")] + MalformedMessage { discarded_bytes: Box<[u8]> }, + #[error("Literal was rejected because it was too long")] + LiteralTooLong { discarded_bytes: Box<[u8]> }, +} diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 00000000..7888ad6b --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,17 @@ +use std::pin::Pin; + +use tokio::io::{AsyncRead, AsyncWrite}; + +// TODO: Reconsider this. Do we really need Stream + AnyStream? What is the smallest API that we need to expose? + +pub trait Stream: AsyncRead + AsyncWrite + Send {} + +impl Stream for S {} + +pub struct AnyStream(pub Pin>); + +impl AnyStream { + pub fn new(stream: S) -> Self { + Self(Box::pin(stream)) + } +}