diff --git a/client/src/client.rs b/client/src/client.rs index fa7817ae..4d287971 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -15,6 +15,7 @@ use crate::{ date::DateTimeService, error::{Error, TimeoutError}, http::{self, uri, Method, Version}, + http_tunnel::HttpTunnelRequest, pool::Pool, request::RequestBuilder, resolver::ResolverService, @@ -105,13 +106,13 @@ impl Client { method!(options, OPTIONS); method!(head, HEAD); - pub fn connect(&self, url: U) -> Result, Error> + pub fn connect(&self, url: U) -> Result, Error> where uri::Uri: TryFrom, Error: From<>::Error>, { self.get(url) - .map(|req| crate::tunnel::TunnelRequest::new(req.method(Method::CONNECT))) + .map(|req| HttpTunnelRequest::new(req.method(Method::CONNECT))) } #[cfg(feature = "websocket")] diff --git a/client/src/http_tunnel.rs b/client/src/http_tunnel.rs new file mode 100644 index 00000000..846caae0 --- /dev/null +++ b/client/src/http_tunnel.rs @@ -0,0 +1,123 @@ +//! http tunnel handling. + +use core::{ + pin::Pin, + task::{ready, Context, Poll}, +}; + +use futures_core::stream::Stream; +use futures_sink::Sink; + +use super::{ + body::ResponseBody, + bytes::{Buf, Bytes, BytesMut}, + error::Error, + http::StatusCode, + tunnel::{Tunnel, TunnelRequest}, +}; + +pub type HttpTunnelRequest<'a> = TunnelRequest<'a, marker::Connect>; + +mod marker { + pub struct Connect; +} + +impl<'a> HttpTunnelRequest<'a> { + /// Send the request and wait for response asynchronously. + pub async fn send(self) -> Result>, Error> { + let res = self.req.send().await?; + + let status = res.status(); + let expect_status = StatusCode::OK; + if status != expect_status { + return Err(Error::Std(format!("expecting {expect_status}, got {status}").into())); + } + + let body = res.res.into_body(); + Ok(Tunnel::new(HttpTunnel { + buf: BytesMut::new(), + body, + })) + } +} + +pub struct HttpTunnel<'b> { + buf: BytesMut, + body: ResponseBody<'b>, +} + +impl Sink for HttpTunnel<'_> +where + M: AsRef<[u8]>, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // TODO: set up a meaningful backpressure limit for send buf. + if !self.buf.chunk().is_empty() { + >::poll_flush(self, cx) + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> { + let inner = self.get_mut(); + inner.buf.extend_from_slice(item.as_ref()); + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let inner = self.get_mut(); + + match inner.body { + #[cfg(feature = "http1")] + ResponseBody::H1(ref mut body) => { + use std::io; + use tokio::io::AsyncWrite; + while !inner.buf.chunk().is_empty() { + match ready!(Pin::new(&mut **body.conn()).poll_write(_cx, inner.buf.chunk()))? { + 0 => return Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof).into())), + n => inner.buf.advance(n), + } + } + + Pin::new(&mut **body.conn()).poll_flush(_cx).map_err(Into::into) + } + #[cfg(feature = "http2")] + ResponseBody::H2(ref mut body) => { + while !inner.buf.chunk().is_empty() { + ready!(body.poll_send_buf(&mut inner.buf, _cx))?; + } + + Poll::Ready(Ok(())) + } + _ => panic!("tunnel can only be enabled when http1 or http2 feature is also enabled"), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(>::poll_flush(self.as_mut(), cx))?; + match self.get_mut().body { + #[cfg(feature = "http1")] + ResponseBody::H1(ref mut body) => { + tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut **body.conn()), cx).map_err(Into::into) + } + #[cfg(feature = "http2")] + ResponseBody::H2(ref mut body) => { + body.send_data(Bytes::new(), true)?; + Poll::Ready(Ok(())) + } + _ => panic!("tunnel can only be enabled when http1 or http2 feature is also enabled"), + } + } +} + +impl Stream for HttpTunnel<'_> { + type Item = Result; + + #[inline] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().body).poll_next(cx).map_err(Into::into) + } +} diff --git a/client/src/lib.rs b/client/src/lib.rs index ae02a7c4..6cab1b03 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -58,6 +58,7 @@ mod h3; pub mod ws; pub mod error; +pub mod http_tunnel; pub mod middleware; pub use self::builder::ClientBuilder; diff --git a/client/src/tunnel.rs b/client/src/tunnel.rs index ab77410e..84afdcc4 100644 --- a/client/src/tunnel.rs +++ b/client/src/tunnel.rs @@ -2,7 +2,7 @@ use core::{ marker::PhantomData, ops::{Deref, DerefMut}, pin::Pin, - task::{ready, Context, Poll}, + task::{Context, Poll}, time::Duration, }; @@ -12,23 +12,16 @@ use futures_core::stream::Stream; use futures_sink::Sink; use crate::{ - body::ResponseBody, - bytes::{Buf, Bytes, BytesMut}, - error::Error, - http::{Method, StatusCode, Version}, + http::{Method, Version}, request::RequestBuilder, }; /// new type of [RequestBuilder] with extended functionality for tunnel handling. -pub struct TunnelRequest<'a, M = marker::Connect> { +pub struct TunnelRequest<'a, M> { pub(crate) req: RequestBuilder<'a>, _marker: PhantomData, } -mod marker { - pub struct Connect; -} - /// new type of [RequestBuilder] with extended functionality for tunnel handling. impl<'a, M> Deref for TunnelRequest<'a, M> { @@ -79,34 +72,18 @@ impl<'a, M> TunnelRequest<'a, M> { } } -impl<'a> TunnelRequest<'a> { - /// Send the request and wait for response asynchronously. - pub async fn send(self) -> Result, Error> { - let res = self.req.send().await?; - - let status = res.status(); - let expect_status = StatusCode::OK; - if status != expect_status { - return Err(Error::Std(format!("expecting {expect_status}, got {status}").into())); - } - - let body = res.res.into_body(); - Tunnel::try_from_body(body) - } -} - /// sender part of tunneled connection. /// [Sink] trait is used to asynchronously send message. -pub struct TunnelSink<'a, 'b>(&'a Tunnel<'b>); +pub struct TunnelSink<'a, I>(&'a Tunnel); -impl Sink for TunnelSink<'_, '_> +impl Sink for TunnelSink<'_, I> where - M: AsRef<[u8]>, + I: Sink + Unpin, { - type Error = Error; + type Error = I::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - as Sink>::poll_ready(Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()), cx) + >::poll_ready(Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()), cx) } fn start_send(self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> { @@ -114,20 +91,23 @@ where } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - as Sink>::poll_flush(Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()), cx) + >::poll_flush(Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()), cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - as Sink>::poll_close(Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()), cx) + >::poll_close(Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()), cx) } } /// sender part of tunnel connection. /// [Stream] trait is used to asynchronously receive message. -pub struct TunnelStream<'a, 'b>(&'a Tunnel<'b>); +pub struct TunnelStream<'a, I>(&'a Tunnel); -impl Stream for TunnelStream<'_, '_> { - type Item = Result; +impl Stream for TunnelStream<'_, I> +where + I: Stream + Unpin, +{ + type Item = I::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()).poll_next(cx) @@ -138,41 +118,41 @@ impl Stream for TunnelStream<'_, '_> { /// /// * This type can not do concurrent message handling which means send always block receive /// and vice versa. -pub struct Tunnel<'c> { - inner: Mutex>, +pub struct Tunnel { + pub(crate) inner: Mutex, } -impl<'a> Tunnel<'a> { - pub(crate) fn try_from_body(body: ResponseBody<'a>) -> Result { - Ok(Self { - inner: Mutex::new(TunnelInner { - buf: BytesMut::new(), - body, - }), - }) - } - +impl Tunnel +where + I: Unpin, +{ /// Split into a sink and reader pair that can be used for concurrent read/write /// message to tunnel connection. #[inline] - pub fn split(&self) -> (TunnelSink<'_, 'a>, TunnelStream<'_, 'a>) { + pub fn split(&self) -> (TunnelSink<'_, I>, TunnelStream<'_, I>) { (TunnelSink(self), TunnelStream(self)) } - fn get_mut_pinned_inner(self: Pin<&mut Self>) -> Pin<&mut TunnelInner<'a>> { + pub(crate) fn new(inner: I) -> Self { + Self { + inner: Mutex::new(inner), + } + } + + fn get_mut_pinned_inner(self: Pin<&mut Self>) -> Pin<&mut I> { Pin::new(self.get_mut().inner.get_mut().unwrap()) } } -impl Sink for Tunnel<'_> +impl Sink for Tunnel where - M: AsRef<[u8]>, + I: Sink + Unpin, { - type Error = Error; + type Error = I::Error; #[inline] fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - as Sink>::poll_ready(self.get_mut_pinned_inner(), cx) + >::poll_ready(self.get_mut_pinned_inner(), cx) } #[inline] @@ -182,101 +162,23 @@ where #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - as Sink>::poll_flush(self.get_mut_pinned_inner(), cx) + >::poll_flush(self.get_mut_pinned_inner(), cx) } #[inline] fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - as Sink>::poll_close(self.get_mut_pinned_inner(), cx) + >::poll_close(self.get_mut_pinned_inner(), cx) } } -impl Stream for Tunnel<'_> { - type Item = Result; - - #[inline] - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut_pinned_inner().poll_next(cx) - } -} - -struct TunnelInner<'b> { - buf: BytesMut, - body: ResponseBody<'b>, -} - -impl Sink for TunnelInner<'_> +impl Stream for Tunnel where - M: AsRef<[u8]>, + I: Stream + Unpin, { - type Error = Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // TODO: set up a meaningful backpressure limit for send buf. - if !self.buf.chunk().is_empty() { - >::poll_flush(self, cx) - } else { - Poll::Ready(Ok(())) - } - } - - fn start_send(self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> { - let inner = self.get_mut(); - inner.buf.extend_from_slice(item.as_ref()); - Ok(()) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - let inner = self.get_mut(); - - match inner.body { - #[cfg(feature = "http1")] - ResponseBody::H1(ref mut body) => { - use std::io; - use tokio::io::AsyncWrite; - while !inner.buf.chunk().is_empty() { - match ready!(Pin::new(&mut **body.conn()).poll_write(_cx, inner.buf.chunk()))? { - 0 => return Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof).into())), - n => inner.buf.advance(n), - } - } - - Pin::new(&mut **body.conn()).poll_flush(_cx).map_err(Into::into) - } - #[cfg(feature = "http2")] - ResponseBody::H2(ref mut body) => { - while !inner.buf.chunk().is_empty() { - ready!(body.poll_send_buf(&mut inner.buf, _cx))?; - } - - Poll::Ready(Ok(())) - } - _ => panic!("tunnel can only be enabled when http1 or http2 feature is also enabled"), - } - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(>::poll_flush(self.as_mut(), cx))?; - match self.get_mut().body { - #[cfg(feature = "http1")] - ResponseBody::H1(ref mut body) => { - tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut **body.conn()), cx).map_err(Into::into) - } - #[cfg(feature = "http2")] - ResponseBody::H2(ref mut body) => { - body.send_data(Bytes::new(), true)?; - Poll::Ready(Ok(())) - } - _ => panic!("tunnel can only be enabled when http1 or http2 feature is also enabled"), - } - } -} - -impl Stream for TunnelInner<'_> { - type Item = Result; + type Item = I::Item; #[inline] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().body).poll_next(cx).map_err(Into::into) + self.get_mut_pinned_inner().poll_next(cx) } } diff --git a/client/src/ws.rs b/client/src/ws.rs index 22c4eaf2..aa05bbb6 100644 --- a/client/src/ws.rs +++ b/client/src/ws.rs @@ -7,8 +7,6 @@ use core::{ task::{ready, Context, Poll}, }; -use std::sync::Mutex; - use futures_core::stream::Stream; use futures_sink::Sink; use http_ws::{Codec, RequestStream, WsError}; @@ -18,16 +16,31 @@ use super::{ bytes::{Buf, BytesMut}, error::Error, http::{StatusCode, Version}, - tunnel::TunnelRequest, + tunnel::{Tunnel, TunnelRequest, TunnelSink, TunnelStream}, }; -/// new type of [RequestBuilder] with extended functionality for websocket handling. -pub type WsRequest<'a> = TunnelRequest<'a, marker::WebSocket>; - mod marker { pub struct WebSocket; } +/// new type of [RequestBuilder] with extended functionality for websocket handling. +pub type WsRequest<'a> = TunnelRequest<'a, marker::WebSocket>; + +/// A unified websocket that can be used as both sender/receiver. It's is a variant of [Tunnel] +/// where it's inner stream and sink is able to produce and consume [Message] type. +/// +/// * This type can not handle concurrent message which means send always block receive and vice +/// versa. +pub type WebSocket<'a> = Tunnel>; + +/// sender part of websocket connection. +/// [Sink] trait is used to asynchronously send message. +pub type WebSocketSink<'a, 'b> = TunnelSink<'a, WebSocketTunnel<'b>>; + +/// sender part of websocket connection. +/// [Stream] trait is used to asynchronously receive message. +pub type WebSocketReader<'a, 'b> = TunnelStream<'a, WebSocketTunnel<'b>>; + impl<'a> WsRequest<'a> { /// Send the request and wait for response asynchronously. pub async fn send(self) -> Result, Error> { @@ -45,72 +58,15 @@ impl<'a> WsRequest<'a> { } let body = res.res.into_body(); - WebSocket::try_from_body(body) - } -} - -/// sender part of websocket connection. -/// [Sink] trait is used to asynchronously send message. -pub struct WebSocketSink<'a, 'b>(&'a WebSocket<'b>); - -impl Sink for WebSocketSink<'_, '_> { - type Error = Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()).poll_ready(cx) - } - - fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()).start_send(item) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()).poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()).poll_close(cx) - } -} - -/// sender part of websocket connection. -/// [Stream] trait is used to asynchronously receive message. -pub struct WebSocketReader<'a, 'b>(&'a WebSocket<'b>); - -impl Stream for WebSocketReader<'_, '_> { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut *self.get_mut().0.inner.lock().unwrap()).poll_next(cx) + Ok(WebSocket::new(WebSocketTunnel { + codec: Codec::new().client_mode(), + send_buf: BytesMut::new(), + recv_stream: RequestStream::with_codec(body, Codec::new().client_mode()), + })) } } -/// A unified websocket that can be used as both sender/receiver. -/// -/// * This type can not do concurrent message handling which means send always block receive -/// and vice versa. -pub struct WebSocket<'c> { - inner: Mutex>, -} - impl<'a> WebSocket<'a> { - pub(crate) fn try_from_body(body: ResponseBody<'a>) -> Result { - Ok(Self { - inner: Mutex::new(WebSocketInner { - codec: Codec::new().client_mode(), - send_buf: BytesMut::new(), - recv_stream: RequestStream::with_codec(body, Codec::new().client_mode()), - }), - }) - } - - /// Split into a sink and reader pair that can be used for concurrent read/write - /// message to websocket connection. - #[inline] - pub fn split(&self) -> (WebSocketSink<'_, 'a>, WebSocketReader<'_, 'a>) { - (WebSocketSink(self), WebSocketReader(self)) - } - /// Set max message size. /// /// By default max size is set to 64kB. @@ -121,52 +77,15 @@ impl<'a> WebSocket<'a> { *recv_codec = recv_codec.set_max_size(size); self } - - fn get_mut_pinned_inner(self: Pin<&mut Self>) -> Pin<&mut WebSocketInner<'a>> { - Pin::new(self.get_mut().inner.get_mut().unwrap()) - } -} - -impl Sink for WebSocket<'_> { - type Error = Error; - - #[inline] - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut_pinned_inner().poll_ready(cx) - } - - #[inline] - fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - self.get_mut_pinned_inner().start_send(item) - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut_pinned_inner().poll_flush(cx) - } - - #[inline] - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut_pinned_inner().poll_close(cx) - } -} - -impl Stream for WebSocket<'_> { - type Item = Result; - - #[inline] - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut_pinned_inner().poll_next(cx) - } } -struct WebSocketInner<'b> { +pub struct WebSocketTunnel<'b> { codec: Codec, send_buf: BytesMut, recv_stream: RequestStream>, } -impl Sink for WebSocketInner<'_> { +impl Sink for WebSocketTunnel<'_> { type Error = Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -230,7 +149,7 @@ impl Sink for WebSocketInner<'_> { } } -impl Stream for WebSocketInner<'_> { +impl Stream for WebSocketTunnel<'_> { type Item = Result; #[inline]