@@ -58,6 +58,8 @@ pub struct Builder<E> {
58
58
http1 : http1:: Builder ,
59
59
#[ cfg( feature = "http2" ) ]
60
60
http2 : http2:: Builder < E > ,
61
+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
62
+ version : Option < Version > ,
61
63
#[ cfg( not( feature = "http2" ) ) ]
62
64
_executor : E ,
63
65
}
@@ -84,6 +86,8 @@ impl<E> Builder<E> {
84
86
http1 : http1:: Builder :: new ( ) ,
85
87
#[ cfg( feature = "http2" ) ]
86
88
http2 : http2:: Builder :: new ( executor) ,
89
+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
90
+ version : None ,
87
91
#[ cfg( not( feature = "http2" ) ) ]
88
92
_executor : executor,
89
93
}
@@ -101,6 +105,26 @@ impl<E> Builder<E> {
101
105
Http2Builder { inner : self }
102
106
}
103
107
108
+ /// Only accepts HTTP/2
109
+ ///
110
+ /// Does not do anything if used with [`serve_connection_with_upgrades`]
111
+ #[ cfg( feature = "http2" ) ]
112
+ pub fn http2_only ( mut self ) -> Self {
113
+ assert ! ( self . version. is_none( ) ) ;
114
+ self . version = Some ( Version :: H2 ) ;
115
+ self
116
+ }
117
+
118
+ /// Only accepts HTTP/1
119
+ ///
120
+ /// Does not do anything if used with [`serve_connection_with_upgrades`]
121
+ #[ cfg( feature = "http1" ) ]
122
+ pub fn http1_only ( mut self ) -> Self {
123
+ assert ! ( self . version. is_none( ) ) ;
124
+ self . version = Some ( Version :: H1 ) ;
125
+ self
126
+ }
127
+
104
128
/// Bind a connection together with a [`Service`].
105
129
pub fn serve_connection < I , S , B > ( & self , io : I , service : S ) -> Connection < ' _ , I , S , E >
106
130
where
@@ -112,13 +136,28 @@ impl<E> Builder<E> {
112
136
I : Read + Write + Unpin + ' static ,
113
137
E : HttpServerConnExec < S :: Future , B > ,
114
138
{
115
- Connection {
116
- state : ConnState :: ReadVersion {
139
+ let state = match self . version {
140
+ #[ cfg( feature = "http1" ) ]
141
+ Some ( Version :: H1 ) => {
142
+ let io = Rewind :: new_buffered ( io, Bytes :: new ( ) ) ;
143
+ let conn = self . http1 . serve_connection ( io, service) ;
144
+ ConnState :: H1 { conn }
145
+ }
146
+ #[ cfg( feature = "http2" ) ]
147
+ Some ( Version :: H2 ) => {
148
+ let io = Rewind :: new_buffered ( io, Bytes :: new ( ) ) ;
149
+ let conn = self . http2 . serve_connection ( io, service) ;
150
+ ConnState :: H2 { conn }
151
+ }
152
+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
153
+ _ => ConnState :: ReadVersion {
117
154
read_version : read_version ( io) ,
118
155
builder : self ,
119
156
service : Some ( service) ,
120
157
} ,
121
- }
158
+ } ;
159
+
160
+ Connection { state }
122
161
}
123
162
124
163
/// Bind a connection together with a [`Service`], with the ability to
@@ -148,7 +187,7 @@ impl<E> Builder<E> {
148
187
}
149
188
}
150
189
151
- #[ derive( Copy , Clone ) ]
190
+ #[ derive( Copy , Clone , Debug ) ]
152
191
enum Version {
153
192
H1 ,
154
193
H2 ,
@@ -865,7 +904,7 @@ mod tests {
865
904
#[ cfg( not( miri) ) ]
866
905
#[ tokio:: test]
867
906
async fn http1 ( ) {
868
- let addr = start_server ( ) . await ;
907
+ let addr = start_server ( false , false ) . await ;
869
908
let mut sender = connect_h1 ( addr) . await ;
870
909
871
910
let response = sender
@@ -881,7 +920,23 @@ mod tests {
881
920
#[ cfg( not( miri) ) ]
882
921
#[ tokio:: test]
883
922
async fn http2 ( ) {
884
- let addr = start_server ( ) . await ;
923
+ let addr = start_server ( false , false ) . await ;
924
+ let mut sender = connect_h2 ( addr) . await ;
925
+
926
+ let response = sender
927
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
928
+ . await
929
+ . unwrap ( ) ;
930
+
931
+ let body = response. into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
932
+
933
+ assert_eq ! ( body, BODY ) ;
934
+ }
935
+
936
+ #[ cfg( not( miri) ) ]
937
+ #[ tokio:: test]
938
+ async fn http2_only ( ) {
939
+ let addr = start_server ( false , true ) . await ;
885
940
let mut sender = connect_h2 ( addr) . await ;
886
941
887
942
let response = sender
@@ -894,6 +949,46 @@ mod tests {
894
949
assert_eq ! ( body, BODY ) ;
895
950
}
896
951
952
+ #[ cfg( not( miri) ) ]
953
+ #[ tokio:: test]
954
+ async fn http2_only_fail_if_client_is_http1 ( ) {
955
+ let addr = start_server ( false , true ) . await ;
956
+ let mut sender = connect_h1 ( addr) . await ;
957
+
958
+ let _ = sender
959
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
960
+ . await
961
+ . expect_err ( "should fail" ) ;
962
+ }
963
+
964
+ #[ cfg( not( miri) ) ]
965
+ #[ tokio:: test]
966
+ async fn http1_only ( ) {
967
+ let addr = start_server ( true , false ) . await ;
968
+ let mut sender = connect_h1 ( addr) . await ;
969
+
970
+ let response = sender
971
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
972
+ . await
973
+ . unwrap ( ) ;
974
+
975
+ let body = response. into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
976
+
977
+ assert_eq ! ( body, BODY ) ;
978
+ }
979
+
980
+ #[ cfg( not( miri) ) ]
981
+ #[ tokio:: test]
982
+ async fn http1_only_fail_if_client_is_http2 ( ) {
983
+ let addr = start_server ( true , false ) . await ;
984
+ let mut sender = connect_h2 ( addr) . await ;
985
+
986
+ let _ = sender
987
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
988
+ . await
989
+ . expect_err ( "should fail" ) ;
990
+ }
991
+
897
992
#[ cfg( not( miri) ) ]
898
993
#[ tokio:: test]
899
994
async fn graceful_shutdown ( ) {
@@ -959,7 +1054,7 @@ mod tests {
959
1054
sender
960
1055
}
961
1056
962
- async fn start_server ( ) -> SocketAddr {
1057
+ async fn start_server ( h1_only : bool , h2_only : bool ) -> SocketAddr {
963
1058
let addr: SocketAddr = ( [ 127 , 0 , 0 , 1 ] , 0 ) . into ( ) ;
964
1059
let listener = TcpListener :: bind ( addr) . await . unwrap ( ) ;
965
1060
@@ -970,9 +1065,14 @@ mod tests {
970
1065
let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
971
1066
let stream = TokioIo :: new ( stream) ;
972
1067
tokio:: task:: spawn ( async move {
973
- let _ = auto:: Builder :: new ( TokioExecutor :: new ( ) )
974
- . serve_connection ( stream, service_fn ( hello) )
975
- . await ;
1068
+ let mut builder = auto:: Builder :: new ( TokioExecutor :: new ( ) ) ;
1069
+ if h1_only {
1070
+ builder = builder. http1_only ( ) ;
1071
+ } else if h2_only {
1072
+ builder = builder. http2_only ( ) ;
1073
+ }
1074
+
1075
+ builder. serve_connection ( stream, service_fn ( hello) ) . await ;
976
1076
} ) ;
977
1077
}
978
1078
} ) ;
0 commit comments