Skip to content

Commit

Permalink
add async io adaptor. (#965)
Browse files Browse the repository at this point in the history
* add async io adaptor.

* rename.

* remove default impl of AsyncRead/Write.

* fix service impl.

* bump xitca-io and xitca-server versioning.

* fix postgres build.

* fix doc test.
  • Loading branch information
fakeshadow authored Mar 4, 2024
1 parent 6033cae commit 9852c36
Show file tree
Hide file tree
Showing 23 changed files with 220 additions and 356 deletions.
1 change: 1 addition & 0 deletions http/CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- `util::service::router::RouterGen` is renamed to `RouteGen`. It's API is shrunk to generating route service only. For route path generating please reference `util::service::router::PathGen`.
- `body::Either` doesn't expose it's enum variants in public API anymore.
- relax `Stream::Item` associated type when impl on `body::BoxBody::new` and `body::ResponseBody::boxed_stream` types. Instead of requiring the stream to yield `Ok<Bytes>` it now accepts types `Ok<impl Into<Bytes>>`.
- update `xitca-io` to `0.2.0`.
- update `xitca-tls` to `0.2.0`.

# 0.3.0
Expand Down
6 changes: 3 additions & 3 deletions http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ http2 = ["h2", "fnv", "futures-util/alloc", "runtime", "slab"]
# http3 specific feature.
http3 = ["xitca-io/http3", "futures-util/alloc", "h3", "h3-quinn", "runtime"]
# openssl as server side tls.
openssl = ["dep:openssl", "runtime"]
openssl = ["dep:openssl", "xitca-tls", "runtime"]
# rustls as server side tls.
rustls = ["xitca-tls/rustls", "runtime"]
# rustls as server side tls.
Expand All @@ -34,7 +34,7 @@ io-uring = ["xitca-io/runtime-uring", "tokio-uring"]
router = ["xitca-router"]

[dependencies]
xitca-io = "0.1"
xitca-io = "0.2"
xitca-service = { version = "0.1", features = ["alloc", "std"] }
xitca-unsafe-collection = { version = "0.1.1", features = ["bytes"] }

Expand Down Expand Up @@ -81,7 +81,7 @@ socket2 = { version = "0.5.1", features = ["all"] }

[dev-dependencies]
criterion = "0.5"
xitca-server = "0.1"
xitca-server = "0.2"

[[bench]]
name = "h1_decode"
Expand Down
6 changes: 3 additions & 3 deletions http/src/h2/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use core::{fmt, pin::pin};
use std::net::SocketAddr;

use futures_core::Stream;
use xitca_io::io::{AsyncIo, AsyncRead, AsyncWrite};
use xitca_io::io::{AsyncIo, PollIoAdapter};
use xitca_service::Service;

use crate::{
Expand Down Expand Up @@ -36,7 +36,7 @@ where

A: Service<St, Response = TlsSt>,
St: AsyncIo,
TlsSt: AsyncRead + AsyncWrite + Unpin,
TlsSt: AsyncIo,

HttpServiceError<S::Error, BE>: From<A::Error>,

Expand All @@ -63,7 +63,7 @@ where

let mut conn = ::h2::server::Builder::new()
.enable_connect_protocol()
.handshake(tls_stream)
.handshake(PollIoAdapter(tls_stream))
.timeout(timer.as_mut())
.await
.map_err(|_| HttpServiceError::Timeout(TimeoutError::H2Handshake))??;
Expand Down
9 changes: 4 additions & 5 deletions http/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ use core::{fmt, marker::PhantomData, pin::pin};

use futures_core::Stream;
use xitca_io::{
io::{AsyncIo, AsyncRead, AsyncWrite},
net::Stream as ServerStream,
net::TcpStream,
io::AsyncIo,
net::{Stream as ServerStream, TcpStream},
};
use xitca_service::{ready::ReadyService, Service};

Expand Down Expand Up @@ -75,7 +74,7 @@ impl<S, ResB, BE, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, con
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<ResB>>,
A: Service<TcpStream>,
A::Response: AsyncIo + AsVersion + AsyncRead + AsyncWrite + Unpin,
A::Response: AsyncIo + AsVersion,
HttpServiceError<S::Error, BE>: From<A::Error>,
S::Error: fmt::Debug,
ResB: Stream<Item = Result<Bytes, BE>>,
Expand Down Expand Up @@ -131,7 +130,7 @@ where

let mut conn = ::h2::server::Builder::new()
.enable_connect_protocol()
.handshake(_tls_stream)
.handshake(xitca_io::io::PollIoAdapter(_tls_stream))
.timeout(timer.as_mut())
.await
.map_err(|_| HttpServiceError::Timeout(TimeoutError::H2Handshake))??;
Expand Down
77 changes: 6 additions & 71 deletions http/src/tls/openssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use core::{
fmt::{self, Debug, Formatter},
future::Future,
pin::Pin,
task::{ready, Context, Poll},
task::{Context, Poll},
};

use std::io;
Expand All @@ -14,7 +14,7 @@ use openssl::{
error::ErrorStack,
ssl::{Error, ErrorCode, ShutdownResult, Ssl, SslStream},
};
use xitca_io::io::{AsyncIo, AsyncRead, AsyncWrite, Interest, ReadBuf, Ready};
use xitca_io::io::{AsyncIo, Interest, Ready};
use xitca_service::Service;

use crate::{http::Version, version::AsVersion};
Expand All @@ -28,7 +28,10 @@ pub struct TlsStream<Io> {
io: SslStream<Io>,
}

impl<Io> AsVersion for TlsStream<Io> {
impl<Io> AsVersion for TlsStream<Io>
where
Io: AsyncIo,
{
fn as_version(&self) -> Version {
self.io
.ssl()
Expand Down Expand Up @@ -157,74 +160,6 @@ impl<Io: AsyncIo> io::Write for TlsStream<Io> {
}
}

impl<Io> AsyncRead for TlsStream<Io>
where
Io: AsyncIo,
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
ready!(this.io.get_ref().poll_ready(Interest::READABLE, cx))?;
match io::Read::read(this, buf.initialize_unfilled()) {
Ok(n) => {
buf.advance(n);
Poll::Ready(Ok(()))
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}
}

impl<Io> AsyncWrite for TlsStream<Io>
where
Io: AsyncIo,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut();
ready!(this.io.get_ref().poll_ready(Interest::WRITABLE, cx))?;

match io::Write::write(this, buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
ready!(this.io.get_ref().poll_ready(Interest::WRITABLE, cx))?;

match io::Write::flush(this) {
Ok(_) => Poll::Ready(Ok(())),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncIo::poll_shutdown(self, cx)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
ready!(this.io.get_ref().poll_ready(Interest::WRITABLE, cx))?;

match io::Write::write_vectored(this, bufs) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}

fn is_write_vectored(&self) -> bool {
self.io.get_ref().is_vectored_write()
}
}

/// Collection of 'openssl' error types.
pub enum OpensslError {
Io(io::Error),
Expand Down
114 changes: 5 additions & 109 deletions http/src/tls/rustls.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
use core::{
convert::Infallible,
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use core::{convert::Infallible, fmt};

use std::{error, io, sync::Arc};

use xitca_io::io::{AsyncIo, AsyncRead, AsyncWrite, Interest, ReadBuf, Ready};
use xitca_io::io::AsyncIo;
use xitca_service::Service;
use xitca_tls::rustls::{Error, ServerConfig, ServerConnection, TlsStream as _TlsStream};

Expand All @@ -19,20 +13,14 @@ use super::error::TlsError;
pub(crate) type RustlsConfig = Arc<ServerConfig>;

/// A stream managed by rustls for tls read/write.
pub struct TlsStream<Io>
where
Io: AsyncIo,
{
inner: _TlsStream<ServerConnection, Io>,
}
pub type TlsStream<Io> = _TlsStream<ServerConnection, Io>;

impl<Io> AsVersion for TlsStream<Io>
where
Io: AsyncIo,
{
fn as_version(&self) -> Version {
self.inner
.session()
self.session()
.alpn_protocol()
.map(Self::from_alpn)
.unwrap_or(Version::HTTP_11)
Expand Down Expand Up @@ -73,99 +61,7 @@ impl<Io: AsyncIo> Service<Io> for TlsAcceptorService {

async fn call(&self, io: Io) -> Result<Self::Response, Self::Error> {
let conn = ServerConnection::new(self.acceptor.clone())?;
let inner = _TlsStream::handshake(io, conn).await?;
Ok(TlsStream { inner })
}
}

impl<Io> AsyncIo for TlsStream<Io>
where
Io: AsyncIo,
{
#[inline]
fn ready(&self, interest: Interest) -> impl Future<Output = io::Result<Ready>> + Send {
self.inner.ready(interest)
}

#[inline]
fn poll_ready(&self, interest: Interest, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
self.inner.poll_ready(interest, cx)
}

fn is_vectored_write(&self) -> bool {
self.inner.is_vectored_write()
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncIo::poll_shutdown(Pin::new(&mut self.get_mut().inner), cx)
}
}

impl<Io: AsyncIo> io::Read for TlsStream<Io> {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
io::Read::read(&mut self.inner, buf)
}
}

impl<Io: AsyncIo> io::Write for TlsStream<Io> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
io::Write::write(&mut self.inner, buf)
}

#[inline]
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
io::Write::write_vectored(&mut self.inner, bufs)
}

#[inline]
fn flush(&mut self) -> io::Result<()> {
io::Write::flush(&mut self.inner)
}
}

impl<Io> AsyncRead for TlsStream<Io>
where
Io: AsyncIo,
{
#[inline]
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
}
}

impl<Io> AsyncWrite for TlsStream<Io>
where
Io: AsyncIo,
{
#[inline]
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
}

#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncIo::poll_shutdown(self, cx)
}

#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
}

#[inline]
fn is_write_vectored(&self) -> bool {
self.inner.is_vectored_write()
_TlsStream::handshake(io, conn).await.map_err(Into::into)
}
}

Expand Down
4 changes: 2 additions & 2 deletions http/src/util/futures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub(crate) use queue::*;

#[cfg(any(feature = "http2", feature = "http3"))]
mod queue {
use std::future::Future;
use core::future::Future;

use futures_util::stream::{FuturesUnordered, StreamExt};

Expand All @@ -17,7 +17,7 @@ mod queue {
#[cfg(any(all(feature = "http2", feature = "io-uring"), feature = "http3"))]
pub(crate) async fn next(&mut self) -> F::Output {
if self.is_empty() {
std::future::pending().await
core::future::pending().await
} else {
self.next2().await
}
Expand Down
6 changes: 6 additions & 0 deletions io/CHANGES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# unreleased 0.2.0
## Add
- `io::PollIoAdapter` as adaptor between `io::AsyncIo` and `io::{AsyncRead, AsyncWrite}` traits.

## Remove
- `io::AsyncRead` and `io::AsyncWrite` traits impl for `net::TcpStream` and `net::UnixStream`. Please use `PollIoAdapter` when these traits are needed for any `AsyncIo` type.
2 changes: 1 addition & 1 deletion io/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "xitca-io"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
license = "Apache-2.0"
description = "async network io types and traits"
Expand Down
Loading

0 comments on commit 9852c36

Please sign in to comment.