Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ResponseSender::close shortcut. #958

Merged
merged 3 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions client/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ impl ClientBuilder {
#[cfg(feature = "openssl")]
/// enable openssl as tls connector.
pub fn openssl(mut self) -> Self {
self.connector = connector::openssl(self.alpn_from_version());
self.connector = connector::openssl::connect(self.alpn_from_version());
self
}

#[cfg(feature = "rustls")]
/// enable rustls as tls connector.
pub fn rustls(mut self) -> Self {
self.connector = connector::rustls(self.alpn_from_version());
self.connector = connector::rustls::connect(self.alpn_from_version());
self
}

Expand Down
44 changes: 26 additions & 18 deletions client/src/tls/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ pub(crate) fn nop() -> Connector {
}

#[cfg(feature = "openssl")]
pub(crate) fn openssl(protocols: &[&[u8]]) -> Connector {
pub(crate) mod openssl {
use core::pin::Pin;

use openssl_crate::ssl::{SslConnector, SslMethod};
use tokio_openssl::SslStream;
use xitca_http::bytes::BufMut;

use super::*;

impl<'n> Service<(&'n str, Box<dyn Io>)> for SslConnector {
type Response = (Box<dyn Io>, Version);
type Error = Error;
Expand All @@ -72,22 +74,24 @@ pub(crate) fn openssl(protocols: &[&[u8]]) -> Connector {
}
}

let mut alpn = Vec::with_capacity(20);
for proto in protocols {
alpn.put_u8(proto.len() as u8);
alpn.put(*proto);
}
pub(crate) fn connect(protocols: &[&[u8]]) -> Connector {
let mut alpn = Vec::with_capacity(20);
for proto in protocols {
alpn.put_u8(proto.len() as u8);
alpn.put(*proto);
}

let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap();
let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap();

ssl.set_alpn_protos(&alpn)
.unwrap_or_else(|e| panic!("Can not set ALPN protocol: {e:?}"));
ssl.set_alpn_protos(&alpn)
.unwrap_or_else(|e| panic!("Can not set ALPN protocol: {e:?}"));

Box::new(ssl.build())
Box::new(ssl.build())
}
}

#[cfg(feature = "rustls")]
pub(crate) fn rustls(protocols: &[&[u8]]) -> Connector {
pub(crate) mod rustls {
use std::sync::Arc;

use rustls_pki_types::ServerName;
Expand All @@ -97,6 +101,8 @@ pub(crate) fn rustls(protocols: &[&[u8]]) -> Connector {
};
use webpki_roots::TLS_SERVER_ROOTS;

use super::*;

impl<'n> Service<(&'n str, Box<dyn Io>)> for TlsConnector {
type Response = (Box<dyn Io>, Version);
type Error = Error;
Expand All @@ -119,15 +125,17 @@ pub(crate) fn rustls(protocols: &[&[u8]]) -> Connector {
}
}

let mut root_certs = RootCertStore::empty();
pub(crate) fn connect(protocols: &[&[u8]]) -> Connector {
let mut root_certs = RootCertStore::empty();

root_certs.extend(TLS_SERVER_ROOTS.iter().cloned());
root_certs.extend(TLS_SERVER_ROOTS.iter().cloned());

let mut config = ClientConfig::builder()
.with_root_certificates(root_certs)
.with_no_client_auth();
let mut config = ClientConfig::builder()
.with_root_certificates(root_certs)
.with_no_client_auth();

config.alpn_protocols = protocols.iter().map(|p| p.to_vec()).collect();
config.alpn_protocols = protocols.iter().map(|p| p.to_vec()).collect();

Box::new(TlsConnector::from(Arc::new(config)))
Box::new(TlsConnector::from(Arc::new(config)))
}
}
1 change: 1 addition & 0 deletions http-ws/CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
## Add
- add `RequestStream::inner_mut` method for accessing inner stream type.
- add `RequestStream::codec_mut` method for accessing `Codec` type.
- add `ResponseSender::close` method for sending close message.

## Change
- reduce `stream::RequestStream`'s generic type params.
Expand Down
7 changes: 7 additions & 0 deletions http-ws/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use tokio::sync::mpsc::{channel, Receiver, Sender};
use super::{
codec::{Codec, Message},
error::ProtocolError,
proto::CloseReason,
};

pin_project! {
Expand Down Expand Up @@ -239,6 +240,12 @@ impl ResponseSender {
pub fn binary(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
self.send(Message::Binary(bin.into()))
}

/// encode [Message::Close] variant and add to [ResponseStream].
/// take ownership of Self as after close message no more message can be sent to client.
pub async fn close(self, reason: Option<impl Into<CloseReason>>) -> Result<(), ProtocolError> {
self.send(Message::Close(reason.map(Into::into))).await
}
}

/// [Weak] version of [ResponseSender].
Expand Down
4 changes: 2 additions & 2 deletions http/src/util/service/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl<R, N, const M: usize> ReadyService for RouteService<R, N, M> {
}

/// Error type of Method not allow for route.
pub struct MethodNotAllowed(pub Vec<Method>);
pub struct MethodNotAllowed(pub Box<Vec<Method>>);

impl MethodNotAllowed {
/// slice of allowed methods of current route.
Expand Down Expand Up @@ -229,7 +229,7 @@ where
type Error = RouterError<R::Error>;

async fn call(&self, _: Req) -> Result<Self::Response, Self::Error> {
Err(RouterError::NotAllowed(MethodNotAllowed(Vec::new())))
Err(RouterError::NotAllowed(MethodNotAllowed(Box::default())))
}
}

Expand Down
2 changes: 1 addition & 1 deletion web/src/service/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ mod service {
Err(e) => Err(match e {
ServeError::NotFound => RouterError::Match(MatchError),
ServeError::MethodNotAllowed => {
RouterError::NotAllowed(MethodNotAllowed(vec![Method::GET, Method::HEAD]))
RouterError::NotAllowed(MethodNotAllowed(Box::new(vec![Method::GET, Method::HEAD])))
}
ServeError::Io(io) => RouterError::Service(Error::from(io)),
_ => RouterError::Service(Error::from(ErrorStatus::bad_request())),
Expand Down
Loading