1
1
use std:: collections:: HashMap ;
2
+ use std:: sync:: Arc ;
2
3
3
4
use crate :: error:: MutinyError ;
4
5
use crate :: storage:: MutinyStorage ;
5
6
use core:: time:: Duration ;
7
+ use gloo_net:: websocket:: futures:: WebSocket ;
6
8
use hex_conservative:: DisplayHex ;
7
9
use once_cell:: sync:: Lazy ;
8
10
use payjoin:: receive:: v2:: Enrolled ;
@@ -69,16 +71,67 @@ impl<S: MutinyStorage> PayjoinStorage for S {
69
71
}
70
72
}
71
73
72
- pub async fn fetch_ohttp_keys ( _ohttp_relay : Url , directory : Url ) -> Result < OhttpKeys , Error > {
73
- let http_client = reqwest :: Client :: builder ( ) . build ( ) . unwrap ( ) ;
74
+ pub async fn fetch_ohttp_keys ( ohttp_relay : Url , directory : Url ) -> Result < OhttpKeys , Error > {
75
+ use futures_util :: { AsyncReadExt , AsyncWriteExt } ;
74
76
75
- let ohttp_keys_res = http_client
76
- . get ( format ! ( "{}/ohttp-keys" , directory. as_ref( ) ) )
77
- . send ( )
78
- . await ?
79
- . bytes ( )
80
- . await ?;
81
- Ok ( OhttpKeys :: decode ( ohttp_keys_res. as_ref ( ) ) . map_err ( |_| Error :: OhttpDecodeFailed ) ?)
77
+ let tls_connector = {
78
+ let root_store = futures_rustls:: rustls:: RootCertStore {
79
+ roots : webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) . collect ( ) ,
80
+ } ;
81
+ let config = futures_rustls:: rustls:: ClientConfig :: builder ( )
82
+ . with_root_certificates ( root_store)
83
+ . with_no_client_auth ( ) ;
84
+ futures_rustls:: TlsConnector :: from ( Arc :: new ( config) )
85
+ } ;
86
+ let directory_host = directory. host_str ( ) . ok_or ( Error :: BadDirectoryHost ) ?;
87
+ let domain = futures_rustls:: rustls:: pki_types:: ServerName :: try_from ( directory_host)
88
+ . map_err ( |_| Error :: BadDirectoryHost ) ?
89
+ . to_owned ( ) ;
90
+
91
+ let ws = WebSocket :: open ( & format ! (
92
+ "wss://{}:443" ,
93
+ ohttp_relay. host_str( ) . ok_or( Error :: BadOhttpWsHost ) ?
94
+ ) )
95
+ . map_err ( |_| Error :: BadOhttpWsHost ) ?;
96
+
97
+ let mut tls_stream = tls_connector
98
+ . connect ( domain, ws)
99
+ . await
100
+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
101
+ let ohttp_keys_req = format ! (
102
+ "GET /ohttp-keys HTTP/1.1\r \n Host: {}\r \n Connection: close\r \n \r \n " ,
103
+ directory_host
104
+ ) ;
105
+ tls_stream
106
+ . write_all ( ohttp_keys_req. as_bytes ( ) )
107
+ . await
108
+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
109
+ tls_stream. flush ( ) . await . unwrap ( ) ;
110
+ let mut response_bytes = Vec :: new ( ) ;
111
+ tls_stream. read_to_end ( & mut response_bytes) . await . unwrap ( ) ;
112
+ let ( _headers, res_body) = separate_headers_and_body ( & response_bytes) ?;
113
+ payjoin:: OhttpKeys :: decode ( & res_body) . map_err ( |_| Error :: OhttpDecodeFailed )
114
+ }
115
+
116
+ fn separate_headers_and_body ( response_bytes : & [ u8 ] ) -> Result < ( & [ u8 ] , & [ u8 ] ) , Error > {
117
+ let separator = b"\r \n \r \n " ;
118
+
119
+ // Search for the separator
120
+ if let Some ( position) = response_bytes
121
+ . windows ( separator. len ( ) )
122
+ . position ( |window| window == separator)
123
+ {
124
+ // The body starts immediately after the separator
125
+ let body_start_index = position + separator. len ( ) ;
126
+ let headers = & response_bytes[ ..position] ;
127
+ let body = & response_bytes[ body_start_index..] ;
128
+
129
+ Ok ( ( headers, body) )
130
+ } else {
131
+ Err ( Error :: RequestFailed (
132
+ "No header-body separator found in the response" . to_string ( ) ,
133
+ ) )
134
+ }
82
135
}
83
136
84
137
#[ derive( Debug ) ]
@@ -90,6 +143,9 @@ pub enum Error {
90
143
OhttpDecodeFailed ,
91
144
Shutdown ,
92
145
SessionExpired ,
146
+ BadDirectoryHost ,
147
+ BadOhttpWsHost ,
148
+ RequestFailed ( String ) ,
93
149
}
94
150
95
151
impl std:: error:: Error for Error { }
@@ -104,6 +160,9 @@ impl std::fmt::Display for Error {
104
160
Error :: OhttpDecodeFailed => write ! ( f, "Failed to decode ohttp keys" ) ,
105
161
Error :: Shutdown => write ! ( f, "Payjoin stopped by application shutdown" ) ,
106
162
Error :: SessionExpired => write ! ( f, "Payjoin session expired. Create a new payment request and have the sender try again." ) ,
163
+ Error :: BadDirectoryHost => write ! ( f, "Bad directory host" ) ,
164
+ Error :: BadOhttpWsHost => write ! ( f, "Bad ohttp ws host" ) ,
165
+ Error :: RequestFailed ( e) => write ! ( f, "Request failed: {}" , e) ,
107
166
}
108
167
}
109
168
}
0 commit comments