1
1
use std:: {
2
+ convert:: TryFrom ,
2
3
future:: Future ,
3
4
io,
4
5
pin:: Pin ,
@@ -8,11 +9,10 @@ use std::{
8
9
9
10
use futures:: future:: { FutureExt , TryFutureExt } ;
10
11
use ring:: digest;
11
- use rustls:: { ClientConfig , Session } ;
12
+ use rustls:: { ClientConfig , ServerName } ;
12
13
use tokio:: io:: { AsyncRead , AsyncWrite , ReadBuf } ;
13
14
use tokio_postgres:: tls:: { ChannelBinding , MakeTlsConnect , TlsConnect } ;
14
15
use tokio_rustls:: { client:: TlsStream , TlsConnector } ;
15
- use webpki:: { DNSName , DNSNameRef } ;
16
16
17
17
#[ derive( Clone ) ]
18
18
pub struct MakeRustlsConnect {
@@ -36,19 +36,21 @@ where
36
36
type Error = io:: Error ;
37
37
38
38
fn make_tls_connect ( & mut self , hostname : & str ) -> io:: Result < RustlsConnect > {
39
- DNSNameRef :: try_from_ascii_str ( hostname)
40
- . map ( |dns_name| RustlsConnect ( Some ( RustlsConnectData {
41
- hostname : dns_name. to_owned ( ) ,
42
- connector : Arc :: clone ( & self . config ) . into ( ) ,
43
- } ) ) )
39
+ ServerName :: try_from ( hostname)
40
+ . map ( |dns_name| {
41
+ RustlsConnect ( Some ( RustlsConnectData {
42
+ hostname : dns_name,
43
+ connector : Arc :: clone ( & self . config ) . into ( ) ,
44
+ } ) )
45
+ } )
44
46
. or ( Ok ( RustlsConnect ( None ) ) )
45
47
}
46
48
}
47
49
48
50
pub struct RustlsConnect ( Option < RustlsConnectData > ) ;
49
51
50
52
struct RustlsConnectData {
51
- hostname : DNSName ,
53
+ hostname : ServerName ,
52
54
connector : TlsConnector ,
53
55
}
54
56
@@ -63,10 +65,11 @@ where
63
65
fn connect ( self , stream : S ) -> Self :: Future {
64
66
match self . 0 {
65
67
None => Box :: pin ( core:: future:: ready ( Err ( io:: ErrorKind :: InvalidInput . into ( ) ) ) ) ,
66
- Some ( c) => c. connector
67
- . connect ( c. hostname . as_ref ( ) , stream)
68
+ Some ( c) => c
69
+ . connector
70
+ . connect ( c. hostname , stream)
68
71
. map_ok ( |s| RustlsStream ( Box :: pin ( s) ) )
69
- . boxed ( )
72
+ . boxed ( ) ,
70
73
}
71
74
}
72
75
}
79
82
{
80
83
fn channel_binding ( & self ) -> ChannelBinding {
81
84
let ( _, session) = self . 0 . get_ref ( ) ;
82
- match session. get_peer_certificates ( ) {
83
- Some ( certs) if certs. len ( ) > 0 => {
85
+ match session. peer_certificates ( ) {
86
+ Some ( certs) if ! certs. is_empty ( ) => {
84
87
let sha256 = digest:: digest ( & digest:: SHA256 , certs[ 0 ] . as_ref ( ) ) ;
85
88
ChannelBinding :: tls_server_end_point ( sha256. as_ref ( ) . into ( ) )
86
89
}
@@ -100,7 +103,6 @@ where
100
103
) -> Poll < tokio:: io:: Result < ( ) > > {
101
104
self . 0 . as_mut ( ) . poll_read ( cx, buf)
102
105
}
103
-
104
106
}
105
107
106
108
impl < S > AsyncWrite for RustlsStream < S >
@@ -122,7 +124,6 @@ where
122
124
fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut Context ) -> Poll < tokio:: io:: Result < ( ) > > {
123
125
self . 0 . as_mut ( ) . poll_shutdown ( cx)
124
126
}
125
-
126
127
}
127
128
128
129
#[ cfg( test) ]
@@ -133,12 +134,17 @@ mod tests {
133
134
async fn it_works ( ) {
134
135
env_logger:: builder ( ) . is_test ( true ) . try_init ( ) . unwrap ( ) ;
135
136
136
- let config = rustls:: ClientConfig :: new ( ) ;
137
+ let config = rustls:: ClientConfig :: builder ( )
138
+ . with_safe_defaults ( )
139
+ . with_root_certificates ( rustls:: RootCertStore :: empty ( ) )
140
+ . with_no_client_auth ( ) ;
137
141
let tls = super :: MakeRustlsConnect :: new ( config) ;
138
- let ( client, conn) =
139
- tokio_postgres:: connect ( "sslmode=require host=localhost port=5432 user=postgres" , tls)
140
- . await
141
- . expect ( "connect" ) ;
142
+ let ( client, conn) = tokio_postgres:: connect (
143
+ "sslmode=require host=localhost port=5432 user=postgres" ,
144
+ tls,
145
+ )
146
+ . await
147
+ . expect ( "connect" ) ;
142
148
tokio:: spawn ( conn. map_err ( |e| panic ! ( "{:?}" , e) ) ) ;
143
149
let stmt = client. prepare ( "SELECT 1" ) . await . expect ( "prepare" ) ;
144
150
let _ = client. query ( & stmt, & [ ] ) . await . expect ( "query" ) ;
0 commit comments