diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs index 098328e49..8045ba108 100644 --- a/tokio-postgres/src/copy_both.rs +++ b/tokio-postgres/src/copy_both.rs @@ -69,10 +69,18 @@ enum SinkState { } pin_project! { - /// A sink for `COPY ... FROM STDIN` query data. + /// A sink & stream for `CopyBoth` replication messages /// /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is /// not, the copy will be aborted. + /// + /// The duplex can be split into the separate sink and stream with the [`split`] method. When + /// using this, they must be re-joined before finishing in order to properly complete the copy. + /// + /// Both the implementation of [`Stream`] and [`Sink`] provide access to the bytes wrapped + /// inside of the `CopyData` wrapper. + /// + /// [`split`]: Self::split pub struct CopyBothDuplex { #[pin] sender: mpsc::Sender, @@ -146,6 +154,53 @@ where pub async fn finish(mut self: Pin<&mut Self>) -> Result { future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await } + + /// Splits the streams into distinct [`Sink`] and [`Stream`] components + /// + /// Please note that there must be an eventual call to [`join`] the two components in order to + /// properly close the connection with [`finish`]; no corresponding method exists for the two + /// halves alone. + /// + /// [`join`]: Self::join + /// [`finish`]: Self::finish + pub fn split(self) -> (Sender, Receiver) { + let send = Sender { + sender: self.sender, + buf: self.buf, + state: self.state, + marker: PhantomData, + closed: false, + }; + + let recv = Receiver { + responses: self.responses, + }; + + (send, recv) + } + + /// Joins the two halves of a `CopyBothDuplex` after a call to [`split`] + /// + /// Note: We do not check that the sender and recevier originated from the same + /// [`CopyBothDuplex`]. If they did not, unexpected behavior *will* occur. + /// + /// ## Panics + /// + /// If the sender has already been closed, this function will panic. + /// + /// [`split`]: Self::split + pub fn join(send: Sender, recv: Receiver) -> Self { + assert!(!send.closed); + + CopyBothDuplex { + sender: send.sender, + responses: recv.responses, + buf: send.buf, + state: send.state, + _p: PhantomPinned, + _p2: PhantomData, + } + } } impl Stream for CopyBothDuplex { @@ -157,6 +212,7 @@ impl Stream for CopyBothDuplex { match ready!(this.responses.poll_next(cx)?) { Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), Message::CopyDone => Poll::Ready(None), + Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))), _ => Poll::Ready(Some(Err(Error::unexpected_message()))), } } @@ -220,6 +276,107 @@ where } } +pin_project! { + /// The receiving half of a [`CopyBothDuplex`] + /// + /// Receiving the next message is done through the [`Stream`] implementation. + pub struct Receiver { + responses: Responses, + } +} + +pin_project! { + /// The sending half of a [`CopyBothDuplex`] + /// + /// Sending each message is done through the [`Sink`] implementation. + pub struct Sender { + #[pin] + sender: mpsc::Sender, + buf: BytesMut, + state: SinkState, + marker: PhantomData, + // True iff the sink has been closed. Causes further operations to panic. + closed: bool, + } +} + +impl Stream for Receiver { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.responses.poll_next(cx)?) { + Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), + Message::CopyDone => Poll::Ready(None), + Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))), + _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } +} + +impl Sink for Sender +where + T: Buf + 'static + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .sender + .poll_ready(cx) + .map_err(|_| Error::closed()) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { + assert!(!self.closed); + + let this = self.project(); + + let data: Box = if item.remaining() > 4096 { + if this.buf.is_empty() { + Box::new(item) + } else { + Box::new(this.buf.split().freeze().chain(item)) + } + } else { + this.buf.put(item); + if this.buf.len() > 4096 { + Box::new(this.buf.split().freeze()) + } else { + return Ok(()); + } + }; + + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if !this.buf.is_empty() { + ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + let data: Box = Box::new(this.buf.split().freeze()); + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .as_mut() + .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed())?; + } + + this.sender.poll_flush(cx).map_err(|_| Error::closed()) + } + + // Closing the sink "normally" will just abort the copy. + fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.closed = true; + Poll::Ready(Ok(())) + } +} + pub async fn copy_both_simple( client: &InnerClient, query: &str, diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index bfa52ed0d..3c1cec5d1 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -119,7 +119,7 @@ pub use crate::cancel_token::CancelToken; pub use crate::client::Client; pub use crate::config::Config; pub use crate::connection::Connection; -pub use crate::copy_both::CopyBothDuplex; +pub use crate::copy_both::{CopyBothDuplex, Receiver as CopyBothStream, Sender as CopyBothSink}; pub use crate::copy_in::CopyInSink; pub use crate::copy_out::CopyOutStream; use crate::error::DbError;