Skip to content

Commit a5f37b7

Browse files
committed
feat: add {http1,http2}_only for auto conn
1 parent 16daef6 commit a5f37b7

File tree

1 file changed

+110
-10
lines changed

1 file changed

+110
-10
lines changed

src/server/conn/auto.rs

+110-10
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ pub struct Builder<E> {
5858
http1: http1::Builder,
5959
#[cfg(feature = "http2")]
6060
http2: http2::Builder<E>,
61+
#[cfg(any(feature = "http1", feature = "http2"))]
62+
version: Option<Version>,
6163
#[cfg(not(feature = "http2"))]
6264
_executor: E,
6365
}
@@ -84,6 +86,8 @@ impl<E> Builder<E> {
8486
http1: http1::Builder::new(),
8587
#[cfg(feature = "http2")]
8688
http2: http2::Builder::new(executor),
89+
#[cfg(any(feature = "http1", feature = "http2"))]
90+
version: None,
8791
#[cfg(not(feature = "http2"))]
8892
_executor: executor,
8993
}
@@ -101,6 +105,26 @@ impl<E> Builder<E> {
101105
Http2Builder { inner: self }
102106
}
103107

108+
/// Only accepts HTTP/2
109+
///
110+
/// Does not do anything if used with [`serve_connection_with_upgrades`]
111+
#[cfg(feature = "http2")]
112+
pub fn http2_only(mut self) -> Self {
113+
assert!(self.version.is_none());
114+
self.version = Some(Version::H2);
115+
self
116+
}
117+
118+
/// Only accepts HTTP/1
119+
///
120+
/// Does not do anything if used with [`serve_connection_with_upgrades`]
121+
#[cfg(feature = "http1")]
122+
pub fn http1_only(mut self) -> Self {
123+
assert!(self.version.is_none());
124+
self.version = Some(Version::H1);
125+
self
126+
}
127+
104128
/// Bind a connection together with a [`Service`].
105129
pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E>
106130
where
@@ -112,13 +136,28 @@ impl<E> Builder<E> {
112136
I: Read + Write + Unpin + 'static,
113137
E: HttpServerConnExec<S::Future, B>,
114138
{
115-
Connection {
116-
state: ConnState::ReadVersion {
139+
let state = match self.version {
140+
#[cfg(feature = "http1")]
141+
Some(Version::H1) => {
142+
let io = Rewind::new_buffered(io, Bytes::new());
143+
let conn = self.http1.serve_connection(io, service);
144+
ConnState::H1 { conn }
145+
}
146+
#[cfg(feature = "http2")]
147+
Some(Version::H2) => {
148+
let io = Rewind::new_buffered(io, Bytes::new());
149+
let conn = self.http2.serve_connection(io, service);
150+
ConnState::H2 { conn }
151+
}
152+
#[cfg(any(feature = "http1", feature = "http2"))]
153+
_ => ConnState::ReadVersion {
117154
read_version: read_version(io),
118155
builder: self,
119156
service: Some(service),
120157
},
121-
}
158+
};
159+
160+
Connection { state }
122161
}
123162

124163
/// Bind a connection together with a [`Service`], with the ability to
@@ -148,7 +187,7 @@ impl<E> Builder<E> {
148187
}
149188
}
150189

151-
#[derive(Copy, Clone)]
190+
#[derive(Copy, Clone, Debug)]
152191
enum Version {
153192
H1,
154193
H2,
@@ -865,7 +904,7 @@ mod tests {
865904
#[cfg(not(miri))]
866905
#[tokio::test]
867906
async fn http1() {
868-
let addr = start_server().await;
907+
let addr = start_server(false, false).await;
869908
let mut sender = connect_h1(addr).await;
870909

871910
let response = sender
@@ -881,7 +920,23 @@ mod tests {
881920
#[cfg(not(miri))]
882921
#[tokio::test]
883922
async fn http2() {
884-
let addr = start_server().await;
923+
let addr = start_server(false, false).await;
924+
let mut sender = connect_h2(addr).await;
925+
926+
let response = sender
927+
.send_request(Request::new(Empty::<Bytes>::new()))
928+
.await
929+
.unwrap();
930+
931+
let body = response.into_body().collect().await.unwrap().to_bytes();
932+
933+
assert_eq!(body, BODY);
934+
}
935+
936+
#[cfg(not(miri))]
937+
#[tokio::test]
938+
async fn http2_only() {
939+
let addr = start_server(false, true).await;
885940
let mut sender = connect_h2(addr).await;
886941

887942
let response = sender
@@ -894,6 +949,46 @@ mod tests {
894949
assert_eq!(body, BODY);
895950
}
896951

952+
#[cfg(not(miri))]
953+
#[tokio::test]
954+
async fn http2_only_fail_if_client_is_http1() {
955+
let addr = start_server(false, true).await;
956+
let mut sender = connect_h1(addr).await;
957+
958+
let _ = sender
959+
.send_request(Request::new(Empty::<Bytes>::new()))
960+
.await
961+
.expect_err("should fail");
962+
}
963+
964+
#[cfg(not(miri))]
965+
#[tokio::test]
966+
async fn http1_only() {
967+
let addr = start_server(true, false).await;
968+
let mut sender = connect_h1(addr).await;
969+
970+
let response = sender
971+
.send_request(Request::new(Empty::<Bytes>::new()))
972+
.await
973+
.unwrap();
974+
975+
let body = response.into_body().collect().await.unwrap().to_bytes();
976+
977+
assert_eq!(body, BODY);
978+
}
979+
980+
#[cfg(not(miri))]
981+
#[tokio::test]
982+
async fn http1_only_fail_if_client_is_http2() {
983+
let addr = start_server(true, false).await;
984+
let mut sender = connect_h2(addr).await;
985+
986+
let _ = sender
987+
.send_request(Request::new(Empty::<Bytes>::new()))
988+
.await
989+
.expect_err("should fail");
990+
}
991+
897992
#[cfg(not(miri))]
898993
#[tokio::test]
899994
async fn graceful_shutdown() {
@@ -959,7 +1054,7 @@ mod tests {
9591054
sender
9601055
}
9611056

962-
async fn start_server() -> SocketAddr {
1057+
async fn start_server(h1_only: bool, h2_only: bool) -> SocketAddr {
9631058
let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
9641059
let listener = TcpListener::bind(addr).await.unwrap();
9651060

@@ -970,9 +1065,14 @@ mod tests {
9701065
let (stream, _) = listener.accept().await.unwrap();
9711066
let stream = TokioIo::new(stream);
9721067
tokio::task::spawn(async move {
973-
let _ = auto::Builder::new(TokioExecutor::new())
974-
.serve_connection(stream, service_fn(hello))
975-
.await;
1068+
let mut builder = auto::Builder::new(TokioExecutor::new());
1069+
if h1_only {
1070+
builder = builder.http1_only();
1071+
} else if h2_only {
1072+
builder = builder.http2_only();
1073+
}
1074+
1075+
builder.serve_connection(stream, service_fn(hello)).await;
9761076
});
9771077
}
9781078
});

0 commit comments

Comments
 (0)