diff --git a/src/client.rs b/src/client.rs index 468e12ff..e4f2ef30 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use bytes::BytesMut; use imap_codec::{ decode::{GreetingDecodeError, ResponseDecodeError}, @@ -13,11 +15,15 @@ use imap_codec::{ use thiserror::Error; use crate::{ + handle::{Handle, HandleGenerator, HandleGeneratorGenerator, RawHandle}, receive::{ReceiveEvent, ReceiveState}, send::SendCommandState, stream::{AnyStream, StreamError}, }; +static HANDLE_GENERATOR_GENERATOR: HandleGeneratorGenerator = + HandleGeneratorGenerator::new(); + #[derive(Debug, Clone, Copy, PartialEq)] pub struct ClientFlowOptions { pub crlf_relaxed: bool, @@ -36,7 +42,7 @@ impl Default for ClientFlowOptions { pub struct ClientFlow { stream: AnyStream, - handle_generator: ClientFlowCommandHandleGenerator, + handle_generator: HandleGenerator, send_command_state: SendCommandState, receive_response_state: ReceiveState, } @@ -82,7 +88,7 @@ impl ClientFlow { let client_flow = Self { stream, - handle_generator: ClientFlowCommandHandleGenerator::default(), + handle_generator: HANDLE_GENERATOR_GENERATOR.generate(), send_command_state, receive_response_state, }; @@ -218,8 +224,24 @@ impl ClientFlow { /// [`ClientFlow::enqueue_command`] it is in the process of being sent until /// [`ClientFlow::progress`] returns a [`ClientFlowEvent::CommandSent`] or /// [`ClientFlowEvent::CommandRejected`] with the corresponding handle. -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] -pub struct ClientFlowCommandHandle(u64); +#[derive(Clone, Copy, Eq, PartialEq, Hash)] +pub struct ClientFlowCommandHandle(RawHandle); + +impl Handle for ClientFlowCommandHandle { + fn from_raw(handle: RawHandle) -> Self { + Self(handle) + } +} + +// Implement a short debug representation that hides the underlying raw handle +impl Debug for ClientFlowCommandHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("ClientFlowCommandHandle") + .field(&self.0.generator_id()) + .field(&self.0.handle_id()) + .finish() + } +} #[derive(Debug)] pub enum ClientFlowEvent { @@ -270,16 +292,3 @@ pub enum ClientFlowError { #[error("Received malformed message")] MalformedMessage { discarded_bytes: Box<[u8]> }, } - -#[derive(Debug, Default)] -struct ClientFlowCommandHandleGenerator { - counter: u64, -} - -impl ClientFlowCommandHandleGenerator { - fn generate(&mut self) -> ClientFlowCommandHandle { - let handle = ClientFlowCommandHandle(self.counter); - self.counter += self.counter.wrapping_add(1); - handle - } -} diff --git a/src/handle.rs b/src/handle.rs new file mode 100644 index 00000000..2987c157 --- /dev/null +++ b/src/handle.rs @@ -0,0 +1,75 @@ +use std::{ + marker::PhantomData, + sync::atomic::{AtomicU64, Ordering}, +}; + +pub trait Handle { + fn from_raw(raw_handle: RawHandle) -> Self; +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub struct RawHandle { + generator_id: u64, + handle_id: u64, +} + +impl RawHandle { + pub fn generator_id(&self) -> u64 { + self.generator_id + } + + pub fn handle_id(&self) -> u64 { + self.handle_id + } +} + +#[derive(Debug)] +pub struct HandleGenerator { + /// This ID is used to bind the handles to the generator instance, i.e. it's possible to + /// distinguish handles generated by different generators. We hope that this might + /// prevent bugs when the library user is dealing with handles from different sources. + generator_id: u64, + next_handle_id: u64, + _h: PhantomData, +} + +impl HandleGenerator { + pub fn generate(&mut self) -> H { + let handle_id = self.next_handle_id; + self.next_handle_id += self.next_handle_id.wrapping_add(1); + + H::from_raw(RawHandle { + generator_id: self.generator_id, + handle_id, + }) + } +} + +#[derive(Debug)] +pub struct HandleGeneratorGenerator { + next_handle_generator_id: AtomicU64, + _h: PhantomData, +} + +impl HandleGeneratorGenerator { + pub const fn new() -> Self { + Self { + next_handle_generator_id: AtomicU64::new(0), + _h: PhantomData, + } + } + + pub fn generate(&self) -> HandleGenerator { + // There is no synchronization required and we only care about each thread seeing a + // unique value. + let generator_id = self + .next_handle_generator_id + .fetch_add(1, Ordering::Relaxed); + + HandleGenerator { + generator_id, + next_handle_id: 0, + _h: PhantomData, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 2d80863c..92291509 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![forbid(unsafe_code)] #![deny(missing_debug_implementations)] pub mod client; +mod handle; mod receive; mod send; pub mod server; diff --git a/src/server.rs b/src/server.rs index 8ce7c4fa..a9f7dbb4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use bytes::BytesMut; use imap_codec::{ decode::CommandDecodeError, @@ -11,11 +13,15 @@ use imap_codec::{ use thiserror::Error; use crate::{ + handle::{Handle, HandleGenerator, HandleGeneratorGenerator, RawHandle}, receive::{ReceiveEvent, ReceiveState}, send::SendResponseState, stream::{AnyStream, StreamError}, }; +static HANDLE_GENERATOR_GENERATOR: HandleGeneratorGenerator = + HandleGeneratorGenerator::new(); + #[derive(Debug, Clone, PartialEq)] pub struct ServerFlowOptions { pub crlf_relaxed: bool, @@ -44,7 +50,7 @@ pub struct ServerFlow { stream: AnyStream, max_literal_size: u32, - handle_generator: ServerFlowResponseHandleGenerator, + handle_generator: HandleGenerator, send_response_state: SendResponseState>, receive_command_state: ReceiveState, @@ -78,7 +84,7 @@ impl ServerFlow { let server_flow = Self { stream, max_literal_size: options.max_literal_size, - handle_generator: ServerFlowResponseHandleGenerator::default(), + handle_generator: HANDLE_GENERATOR_GENERATOR.generate(), send_response_state, receive_command_state, literal_accept_text: options.literal_accept_text, @@ -237,8 +243,24 @@ impl ServerFlow { /// [`ServerFlow::enqueue_data`] or [`ServerFlow::enqueue_status`] it is in the process of being /// sent until [`ServerFlow::progress`] returns a [`ServerFlowEvent::ResponseSent`] with the /// corresponding handle. -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] -pub struct ServerFlowResponseHandle(u64); +#[derive(Clone, Copy, Eq, PartialEq, Hash)] +pub struct ServerFlowResponseHandle(RawHandle); + +impl Handle for ServerFlowResponseHandle { + fn from_raw(raw_handle: RawHandle) -> Self { + Self(raw_handle) + } +} + +// Implement a short debug representation that hides the underlying raw handle +impl Debug for ServerFlowResponseHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("ServerFlowResponseHandle") + .field(&self.0.generator_id()) + .field(&self.0.handle_id()) + .finish() + } +} #[derive(Debug)] pub enum ServerFlowEvent { @@ -265,16 +287,3 @@ pub enum ServerFlowError { #[error("Literal was rejected because it was too long")] LiteralTooLong { discarded_bytes: Box<[u8]> }, } - -#[derive(Debug, Default)] -struct ServerFlowResponseHandleGenerator { - counter: u64, -} - -impl ServerFlowResponseHandleGenerator { - fn generate(&mut self) -> ServerFlowResponseHandle { - let handle = ServerFlowResponseHandle(self.counter); - self.counter += self.counter.wrapping_add(1); - handle - } -}