-
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0a89b5e
commit c5f696d
Showing
7 changed files
with
794 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
use bounded_static::ToBoundedStatic; | ||
use bytes::BytesMut; | ||
use imap_codec::{ | ||
decode::{GreetingDecodeError, ResponseDecodeError}, | ||
imap_types::{ | ||
command::Command, | ||
core::Tag, | ||
response::{Data, Greeting, Response, Status}, | ||
}, | ||
CommandCodec, GreetingCodec, ResponseCodec, | ||
}; | ||
use thiserror::Error; | ||
|
||
use crate::{ | ||
receive::{ReceiveEvent, ReceiveState}, | ||
send::SendCommandState, | ||
stream::AnyStream, | ||
}; | ||
|
||
#[derive(Debug, Clone, Copy, PartialEq)] | ||
pub struct ClientFlowOptions { | ||
pub crlf_relaxed: bool, | ||
} | ||
|
||
pub struct ClientFlow { | ||
stream: AnyStream, | ||
|
||
next_command_handle: ClientFlowCommandHandle, | ||
send_command_state: SendCommandState<(Tag<'static>, ClientFlowCommandHandle)>, | ||
receive_response_state: ReceiveState<ResponseCodec>, | ||
} | ||
|
||
impl ClientFlow { | ||
pub async fn receive_greeting( | ||
mut stream: AnyStream, | ||
options: ClientFlowOptions, | ||
) -> Result<(Self, Greeting<'static>), ClientFlowError> { | ||
// Receive greeting | ||
let read_buffer = BytesMut::new(); | ||
let mut receive_greeting_state = | ||
ReceiveState::new(GreetingCodec::default(), options.crlf_relaxed, read_buffer); | ||
let greeting = match receive_greeting_state.progress(&mut stream).await? { | ||
ReceiveEvent::DecodingSuccess(greeting) => { | ||
receive_greeting_state.finish_message(); | ||
greeting | ||
} | ||
ReceiveEvent::DecodingFailure( | ||
GreetingDecodeError::Failed | GreetingDecodeError::Incomplete, | ||
) => { | ||
let discarded_bytes = receive_greeting_state.discard_message(); | ||
return Err(ClientFlowError::MalformedMessage { discarded_bytes }); | ||
} | ||
ReceiveEvent::ExpectedCrlfGotLf => { | ||
let discarded_bytes = receive_greeting_state.discard_message(); | ||
return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes }); | ||
} | ||
}; | ||
|
||
// Successfully received greeting, create instance. | ||
let write_buffer = BytesMut::new(); | ||
let send_command_state = SendCommandState::new(CommandCodec::default(), write_buffer); | ||
let read_buffer = receive_greeting_state.finish(); | ||
let receive_response_state = | ||
ReceiveState::new(ResponseCodec::new(), options.crlf_relaxed, read_buffer); | ||
let client_flow = Self { | ||
stream, | ||
next_command_handle: ClientFlowCommandHandle(0), | ||
send_command_state, | ||
receive_response_state, | ||
}; | ||
|
||
Ok((client_flow, greeting)) | ||
} | ||
|
||
pub fn enqueue_command(&mut self, command: Command<'_>) -> ClientFlowCommandHandle { | ||
let handle = self.next_command_handle; | ||
self.next_command_handle = ClientFlowCommandHandle(handle.0 + 1); | ||
let tag = command.tag.to_static(); | ||
self.send_command_state.enqueue((tag, handle), command); | ||
handle | ||
} | ||
|
||
pub async fn progress(&mut self) -> Result<ClientFlowEvent, ClientFlowError> { | ||
loop { | ||
if let Some(event) = self.progress_command().await? { | ||
return Ok(event); | ||
} | ||
|
||
if let Some(event) = self.progress_response().await? { | ||
return Ok(event); | ||
} | ||
} | ||
} | ||
|
||
async fn progress_command(&mut self) -> Result<Option<ClientFlowEvent>, ClientFlowError> { | ||
match self.send_command_state.progress(&mut self.stream).await? { | ||
Some((tag, handle)) => Ok(Some(ClientFlowEvent::CommandSent { tag, handle })), | ||
None => Ok(None), | ||
} | ||
} | ||
|
||
async fn progress_response(&mut self) -> Result<Option<ClientFlowEvent>, ClientFlowError> { | ||
let event = loop { | ||
let response = match self | ||
.receive_response_state | ||
.progress(&mut self.stream) | ||
.await? | ||
{ | ||
ReceiveEvent::DecodingSuccess(response) => { | ||
self.receive_response_state.finish_message(); | ||
response | ||
} | ||
ReceiveEvent::DecodingFailure(ResponseDecodeError::LiteralFound { length }) => { | ||
// The client must accept the literal in any case. | ||
self.receive_response_state.start_literal(length); | ||
continue; | ||
} | ||
ReceiveEvent::DecodingFailure( | ||
ResponseDecodeError::Failed | ResponseDecodeError::Incomplete, | ||
) => { | ||
let discarded_bytes = self.receive_response_state.discard_message(); | ||
return Err(ClientFlowError::MalformedMessage { discarded_bytes }); | ||
} | ||
ReceiveEvent::ExpectedCrlfGotLf => { | ||
let discarded_bytes = self.receive_response_state.discard_message(); | ||
return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes }); | ||
} | ||
}; | ||
|
||
match response { | ||
Response::Status(status) => { | ||
self.maybe_abort_command(&status); | ||
break Some(ClientFlowEvent::StatusReceived { status }); | ||
} | ||
Response::Data(data) => break Some(ClientFlowEvent::DataReceived { data }), | ||
Response::CommandContinuationRequest(_) => { | ||
self.send_command_state.continue_command(); | ||
break None; | ||
} | ||
} | ||
}; | ||
|
||
Ok(event) | ||
} | ||
|
||
fn maybe_abort_command(&mut self, status: &Status) { | ||
let Some((command_tag, _)) = self.send_command_state.command_in_progress() else { | ||
return; | ||
}; | ||
|
||
match status { | ||
Status::Bad { | ||
tag: Some(status_tag), | ||
.. | ||
} if status_tag == command_tag => { | ||
self.send_command_state.abort_command(); | ||
} | ||
_ => (), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] | ||
pub struct ClientFlowCommandHandle(u64); | ||
|
||
#[derive(Debug)] | ||
pub enum ClientFlowEvent { | ||
CommandSent { | ||
tag: Tag<'static>, | ||
handle: ClientFlowCommandHandle, | ||
}, | ||
DataReceived { | ||
data: Data<'static>, | ||
}, | ||
StatusReceived { | ||
status: Status<'static>, | ||
}, | ||
} | ||
|
||
#[derive(Debug, Error)] | ||
pub enum ClientFlowError { | ||
#[error(transparent)] | ||
Io(#[from] tokio::io::Error), | ||
#[error("Expected `\\r\\n`, got `\\n`")] | ||
ExpectedCrlfGotLf { discarded_bytes: Box<[u8]> }, | ||
#[error("Received malformed message")] | ||
MalformedMessage { discarded_bytes: Box<[u8]> }, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,5 @@ | ||
pub fn add(left: usize, right: usize) -> usize { | ||
left + right | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn it_works() { | ||
let result = add(2, 2); | ||
assert_eq!(result, 4); | ||
} | ||
} | ||
pub mod client; | ||
mod receive; | ||
mod send; | ||
pub mod server; | ||
pub mod stream; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
use bounded_static::IntoBoundedStatic; | ||
use bytes::{Buf, BytesMut}; | ||
use imap_codec::decode::Decoder; | ||
use tokio::io::AsyncReadExt; | ||
|
||
use crate::stream::AnyStream; | ||
|
||
pub struct ReceiveState<C: Decoder> { | ||
codec: C, | ||
crlf_relaxed: bool, | ||
next_fragment: NextFragment, | ||
// How many bytes in the parse buffer do we already have checked? | ||
// This is important if we need multiple attempts to read from the underlying | ||
// stream before the message is completely received. | ||
seen_bytes: usize, | ||
// Used for reading the current message from the stream. | ||
// Its length should always be equal to or greater than `seen_bytes`. | ||
read_buffer: BytesMut, | ||
} | ||
|
||
impl<C: Decoder> ReceiveState<C> { | ||
pub fn new(codec: C, crlf_relaxed: bool, read_buffer: BytesMut) -> Self { | ||
Self { | ||
codec, | ||
crlf_relaxed, | ||
next_fragment: NextFragment::default(), | ||
seen_bytes: 0, | ||
read_buffer, | ||
} | ||
} | ||
|
||
pub fn start_literal(&mut self, length: u32) { | ||
self.next_fragment = NextFragment::Literal { length }; | ||
self.read_buffer.reserve(length as usize); | ||
} | ||
|
||
pub fn finish_message(&mut self) { | ||
self.read_buffer.advance(self.seen_bytes); | ||
self.seen_bytes = 0; | ||
self.next_fragment = NextFragment::default(); | ||
} | ||
|
||
pub fn discard_message(&mut self) -> Box<[u8]> { | ||
let discarded_bytes = self.read_buffer[..self.seen_bytes].into(); | ||
self.finish_message(); | ||
discarded_bytes | ||
} | ||
|
||
pub fn finish(self) -> BytesMut { | ||
self.read_buffer | ||
} | ||
|
||
pub async fn progress( | ||
&mut self, | ||
stream: &mut AnyStream, | ||
) -> Result<ReceiveEvent<C>, tokio::io::Error> | ||
where | ||
for<'a> C::Message<'a>: IntoBoundedStatic<Static = C::Message<'static>>, | ||
for<'a> C::Error<'a>: IntoBoundedStatic<Static = C::Error<'static>>, | ||
{ | ||
loop { | ||
match self.next_fragment { | ||
NextFragment::Line => { | ||
if let Some(event) = self.progress_line(stream).await? { | ||
return Ok(event); | ||
} | ||
} | ||
NextFragment::Literal { length } => { | ||
self.progress_literal(stream, length).await?; | ||
} | ||
}; | ||
} | ||
} | ||
|
||
async fn progress_line( | ||
&mut self, | ||
stream: &mut AnyStream, | ||
) -> Result<Option<ReceiveEvent<C>>, tokio::io::Error> | ||
where | ||
for<'a> C::Message<'a>: IntoBoundedStatic<Static = C::Message<'static>>, | ||
for<'a> C::Error<'a>: IntoBoundedStatic<Static = C::Error<'static>>, | ||
{ | ||
// TODO: If the line is really long and we need multiple attempts to receive it, then this is O(n^2). | ||
// This could be fixed by setting seen bytes in the None case | ||
let crlf_result = match find_crlf(&self.read_buffer[self.seen_bytes..], self.crlf_relaxed) { | ||
Some(crlf_result) => crlf_result, | ||
None => { | ||
// No full line received yet, more data needed. | ||
stream.0.read_buf(&mut self.read_buffer).await?; | ||
return Ok(None); | ||
} | ||
}; | ||
|
||
// Mark the all bytes of the current line as seen. | ||
self.seen_bytes += crlf_result.lf_position + 1; | ||
|
||
if crlf_result.expected_crlf_got_lf { | ||
return Ok(Some(ReceiveEvent::ExpectedCrlfGotLf)); | ||
} | ||
|
||
// Try to parse the whole message from the start (including the new line). | ||
// TODO: If the message is really long and we need multiple attempts to receive it, then this is O(n^2) | ||
// IMO this can be only fixed by using a generator-like decoder | ||
match self.codec.decode(&self.read_buffer[..self.seen_bytes]) { | ||
Ok((remaining, message)) => { | ||
assert!(remaining.is_empty()); | ||
Ok(Some(ReceiveEvent::DecodingSuccess(message.into_static()))) | ||
} | ||
Err(error) => Ok(Some(ReceiveEvent::DecodingFailure(error.into_static()))), | ||
} | ||
} | ||
|
||
async fn progress_literal( | ||
&mut self, | ||
stream: &mut AnyStream, | ||
literal_length: u32, | ||
) -> Result<(), tokio::io::Error> { | ||
let unseen_bytes = self.read_buffer.len() - self.seen_bytes; | ||
|
||
if unseen_bytes < literal_length as usize { | ||
// We did not receive enough bytes for the literal yet. | ||
stream.0.read_buf(&mut self.read_buffer).await?; | ||
} else { | ||
// We received enough bytes for the literal. | ||
// Now we can continue reading the next line. | ||
self.next_fragment = NextFragment::Line; | ||
self.seen_bytes += literal_length as usize; | ||
} | ||
|
||
Ok(()) | ||
} | ||
} | ||
|
||
pub enum ReceiveEvent<C: Decoder> { | ||
DecodingSuccess(C::Message<'static>), | ||
DecodingFailure(C::Error<'static>), | ||
ExpectedCrlfGotLf, | ||
} | ||
|
||
// The next fragment that will be read... | ||
#[derive(Clone, Copy, Default)] | ||
enum NextFragment { | ||
// ... is a line. | ||
// | ||
// Note: A message always starts (and ends) with a line. | ||
#[default] | ||
Line, | ||
// ... is a literal with the given length. | ||
Literal { | ||
length: u32, | ||
}, | ||
} | ||
|
||
// A line ending for the current line was found. | ||
struct FindCrlfResult { | ||
// The position of the `\n` symbol | ||
lf_position: usize, | ||
// Is the line ending `\n` even though we expected `\r\n`? | ||
expected_crlf_got_lf: bool, | ||
} | ||
|
||
// Finds the line ending for the current line. | ||
// Depending on `crlf_relaxed` the accepted line ending is `\n` (true) or `\r\n` (false). | ||
fn find_crlf(buf: &[u8], crlf_relaxed: bool) -> Option<FindCrlfResult> { | ||
let lf_position = buf.iter().position(|item| *item == b'\n')?; | ||
let expected_crlf_got_lf = !crlf_relaxed && buf[lf_position.saturating_sub(1)] != b'\r'; | ||
Some(FindCrlfResult { | ||
lf_position, | ||
expected_crlf_got_lf, | ||
}) | ||
} |
Oops, something went wrong.