Skip to content

Commit

Permalink
Implement first prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
jakoschiko committed Oct 8, 2023
1 parent 0a89b5e commit c5f696d
Show file tree
Hide file tree
Showing 7 changed files with 794 additions and 16 deletions.
7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ name = "imap-flow"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
bounded-static = "0.5.0"
bytes = "1.5.0"
imap-codec = { version = "1.0.0", features = ["quirk_crlf_relaxed", "bounded-static"] }
thiserror = "1.0.49"
tokio = { version = "1.32.0", features = ["io-util"] }
188 changes: 188 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
use bounded_static::ToBoundedStatic;
use bytes::BytesMut;
use imap_codec::{
decode::{GreetingDecodeError, ResponseDecodeError},
imap_types::{
command::Command,
core::Tag,
response::{Data, Greeting, Response, Status},
},
CommandCodec, GreetingCodec, ResponseCodec,
};
use thiserror::Error;

use crate::{
receive::{ReceiveEvent, ReceiveState},
send::SendCommandState,
stream::AnyStream,
};

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ClientFlowOptions {
pub crlf_relaxed: bool,
}

pub struct ClientFlow {
stream: AnyStream,

next_command_handle: ClientFlowCommandHandle,
send_command_state: SendCommandState<(Tag<'static>, ClientFlowCommandHandle)>,
receive_response_state: ReceiveState<ResponseCodec>,
}

impl ClientFlow {
pub async fn receive_greeting(
mut stream: AnyStream,
options: ClientFlowOptions,
) -> Result<(Self, Greeting<'static>), ClientFlowError> {
// Receive greeting
let read_buffer = BytesMut::new();
let mut receive_greeting_state =
ReceiveState::new(GreetingCodec::default(), options.crlf_relaxed, read_buffer);
let greeting = match receive_greeting_state.progress(&mut stream).await? {
ReceiveEvent::DecodingSuccess(greeting) => {
receive_greeting_state.finish_message();
greeting
}
ReceiveEvent::DecodingFailure(
GreetingDecodeError::Failed | GreetingDecodeError::Incomplete,
) => {
let discarded_bytes = receive_greeting_state.discard_message();
return Err(ClientFlowError::MalformedMessage { discarded_bytes });
}
ReceiveEvent::ExpectedCrlfGotLf => {
let discarded_bytes = receive_greeting_state.discard_message();
return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes });
}
};

// Successfully received greeting, create instance.
let write_buffer = BytesMut::new();
let send_command_state = SendCommandState::new(CommandCodec::default(), write_buffer);
let read_buffer = receive_greeting_state.finish();
let receive_response_state =
ReceiveState::new(ResponseCodec::new(), options.crlf_relaxed, read_buffer);
let client_flow = Self {
stream,
next_command_handle: ClientFlowCommandHandle(0),
send_command_state,
receive_response_state,
};

Ok((client_flow, greeting))
}

pub fn enqueue_command(&mut self, command: Command<'_>) -> ClientFlowCommandHandle {
let handle = self.next_command_handle;
self.next_command_handle = ClientFlowCommandHandle(handle.0 + 1);
let tag = command.tag.to_static();
self.send_command_state.enqueue((tag, handle), command);
handle
}

pub async fn progress(&mut self) -> Result<ClientFlowEvent, ClientFlowError> {
loop {
if let Some(event) = self.progress_command().await? {
return Ok(event);
}

if let Some(event) = self.progress_response().await? {
return Ok(event);
}
}
}

async fn progress_command(&mut self) -> Result<Option<ClientFlowEvent>, ClientFlowError> {
match self.send_command_state.progress(&mut self.stream).await? {
Some((tag, handle)) => Ok(Some(ClientFlowEvent::CommandSent { tag, handle })),
None => Ok(None),
}
}

async fn progress_response(&mut self) -> Result<Option<ClientFlowEvent>, ClientFlowError> {
let event = loop {
let response = match self
.receive_response_state
.progress(&mut self.stream)
.await?
{
ReceiveEvent::DecodingSuccess(response) => {
self.receive_response_state.finish_message();
response
}
ReceiveEvent::DecodingFailure(ResponseDecodeError::LiteralFound { length }) => {
// The client must accept the literal in any case.
self.receive_response_state.start_literal(length);
continue;
}
ReceiveEvent::DecodingFailure(
ResponseDecodeError::Failed | ResponseDecodeError::Incomplete,
) => {
let discarded_bytes = self.receive_response_state.discard_message();
return Err(ClientFlowError::MalformedMessage { discarded_bytes });
}
ReceiveEvent::ExpectedCrlfGotLf => {
let discarded_bytes = self.receive_response_state.discard_message();
return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes });
}
};

match response {
Response::Status(status) => {
self.maybe_abort_command(&status);
break Some(ClientFlowEvent::StatusReceived { status });
}
Response::Data(data) => break Some(ClientFlowEvent::DataReceived { data }),
Response::CommandContinuationRequest(_) => {
self.send_command_state.continue_command();
break None;
}
}
};

Ok(event)
}

fn maybe_abort_command(&mut self, status: &Status) {
let Some((command_tag, _)) = self.send_command_state.command_in_progress() else {
return;
};

match status {
Status::Bad {
tag: Some(status_tag),
..
} if status_tag == command_tag => {
self.send_command_state.abort_command();
}
_ => (),
}
}
}

#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ClientFlowCommandHandle(u64);

#[derive(Debug)]
pub enum ClientFlowEvent {
CommandSent {
tag: Tag<'static>,
handle: ClientFlowCommandHandle,
},
DataReceived {
data: Data<'static>,
},
StatusReceived {
status: Status<'static>,
},
}

#[derive(Debug, Error)]
pub enum ClientFlowError {
#[error(transparent)]
Io(#[from] tokio::io::Error),
#[error("Expected `\\r\\n`, got `\\n`")]
ExpectedCrlfGotLf { discarded_bytes: Box<[u8]> },
#[error("Received malformed message")]
MalformedMessage { discarded_bytes: Box<[u8]> },
}
19 changes: 5 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
pub fn add(left: usize, right: usize) -> usize {
left + right
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
}
}
pub mod client;
mod receive;
mod send;
pub mod server;
pub mod stream;
171 changes: 171 additions & 0 deletions src/receive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
use bounded_static::IntoBoundedStatic;
use bytes::{Buf, BytesMut};
use imap_codec::decode::Decoder;
use tokio::io::AsyncReadExt;

use crate::stream::AnyStream;

pub struct ReceiveState<C: Decoder> {
codec: C,
crlf_relaxed: bool,
next_fragment: NextFragment,
// How many bytes in the parse buffer do we already have checked?
// This is important if we need multiple attempts to read from the underlying
// stream before the message is completely received.
seen_bytes: usize,
// Used for reading the current message from the stream.
// Its length should always be equal to or greater than `seen_bytes`.
read_buffer: BytesMut,
}

impl<C: Decoder> ReceiveState<C> {
pub fn new(codec: C, crlf_relaxed: bool, read_buffer: BytesMut) -> Self {
Self {
codec,
crlf_relaxed,
next_fragment: NextFragment::default(),
seen_bytes: 0,
read_buffer,
}
}

pub fn start_literal(&mut self, length: u32) {
self.next_fragment = NextFragment::Literal { length };
self.read_buffer.reserve(length as usize);
}

pub fn finish_message(&mut self) {
self.read_buffer.advance(self.seen_bytes);
self.seen_bytes = 0;
self.next_fragment = NextFragment::default();
}

pub fn discard_message(&mut self) -> Box<[u8]> {
let discarded_bytes = self.read_buffer[..self.seen_bytes].into();
self.finish_message();
discarded_bytes
}

pub fn finish(self) -> BytesMut {
self.read_buffer
}

pub async fn progress(
&mut self,
stream: &mut AnyStream,
) -> Result<ReceiveEvent<C>, tokio::io::Error>
where
for<'a> C::Message<'a>: IntoBoundedStatic<Static = C::Message<'static>>,
for<'a> C::Error<'a>: IntoBoundedStatic<Static = C::Error<'static>>,
{
loop {
match self.next_fragment {
NextFragment::Line => {
if let Some(event) = self.progress_line(stream).await? {
return Ok(event);
}
}
NextFragment::Literal { length } => {
self.progress_literal(stream, length).await?;
}
};
}
}

async fn progress_line(
&mut self,
stream: &mut AnyStream,
) -> Result<Option<ReceiveEvent<C>>, tokio::io::Error>
where
for<'a> C::Message<'a>: IntoBoundedStatic<Static = C::Message<'static>>,
for<'a> C::Error<'a>: IntoBoundedStatic<Static = C::Error<'static>>,
{
// TODO: If the line is really long and we need multiple attempts to receive it, then this is O(n^2).
// This could be fixed by setting seen bytes in the None case
let crlf_result = match find_crlf(&self.read_buffer[self.seen_bytes..], self.crlf_relaxed) {
Some(crlf_result) => crlf_result,
None => {
// No full line received yet, more data needed.
stream.0.read_buf(&mut self.read_buffer).await?;
return Ok(None);
}
};

// Mark the all bytes of the current line as seen.
self.seen_bytes += crlf_result.lf_position + 1;

if crlf_result.expected_crlf_got_lf {
return Ok(Some(ReceiveEvent::ExpectedCrlfGotLf));
}

// Try to parse the whole message from the start (including the new line).
// TODO: If the message is really long and we need multiple attempts to receive it, then this is O(n^2)
// IMO this can be only fixed by using a generator-like decoder
match self.codec.decode(&self.read_buffer[..self.seen_bytes]) {
Ok((remaining, message)) => {
assert!(remaining.is_empty());
Ok(Some(ReceiveEvent::DecodingSuccess(message.into_static())))
}
Err(error) => Ok(Some(ReceiveEvent::DecodingFailure(error.into_static()))),
}
}

async fn progress_literal(
&mut self,
stream: &mut AnyStream,
literal_length: u32,
) -> Result<(), tokio::io::Error> {
let unseen_bytes = self.read_buffer.len() - self.seen_bytes;

if unseen_bytes < literal_length as usize {
// We did not receive enough bytes for the literal yet.
stream.0.read_buf(&mut self.read_buffer).await?;
} else {
// We received enough bytes for the literal.
// Now we can continue reading the next line.
self.next_fragment = NextFragment::Line;
self.seen_bytes += literal_length as usize;
}

Ok(())
}
}

pub enum ReceiveEvent<C: Decoder> {
DecodingSuccess(C::Message<'static>),
DecodingFailure(C::Error<'static>),
ExpectedCrlfGotLf,
}

// The next fragment that will be read...
#[derive(Clone, Copy, Default)]
enum NextFragment {
// ... is a line.
//
// Note: A message always starts (and ends) with a line.
#[default]
Line,
// ... is a literal with the given length.
Literal {
length: u32,
},
}

// A line ending for the current line was found.
struct FindCrlfResult {
// The position of the `\n` symbol
lf_position: usize,
// Is the line ending `\n` even though we expected `\r\n`?
expected_crlf_got_lf: bool,
}

// Finds the line ending for the current line.
// Depending on `crlf_relaxed` the accepted line ending is `\n` (true) or `\r\n` (false).
fn find_crlf(buf: &[u8], crlf_relaxed: bool) -> Option<FindCrlfResult> {
let lf_position = buf.iter().position(|item| *item == b'\n')?;
let expected_crlf_got_lf = !crlf_relaxed && buf[lf_position.saturating_sub(1)] != b'\r';
Some(FindCrlfResult {
lf_position,
expected_crlf_got_lf,
})
}
Loading

0 comments on commit c5f696d

Please sign in to comment.