|
1 | 1 | use std::collections::HashMap;
|
| 2 | +use std::sync::Arc; |
2 | 3 |
|
3 | 4 | use crate::error::MutinyError;
|
4 | 5 | use crate::storage::MutinyStorage;
|
5 | 6 | use bitcoin::Transaction;
|
6 | 7 | use core::time::Duration;
|
| 8 | +use gloo_net::websocket::futures::WebSocket; |
7 | 9 | use hex_conservative::DisplayHex;
|
8 | 10 | use once_cell::sync::Lazy;
|
9 | 11 | use payjoin::receive::v2::Enrolled;
|
@@ -76,16 +78,73 @@ impl<S: MutinyStorage> PayjoinStorage for S {
|
76 | 78 | }
|
77 | 79 | }
|
78 | 80 |
|
79 |
| -pub async fn fetch_ohttp_keys(_ohttp_relay: Url, directory: Url) -> Result<OhttpKeys, Error> { |
80 |
| - let http_client = reqwest::Client::builder().build()?; |
| 81 | +pub async fn fetch_ohttp_keys(ohttp_relay: Url, directory: Url) -> Result<OhttpKeys, Error> { |
| 82 | + use futures_util::{AsyncReadExt, AsyncWriteExt}; |
81 | 83 |
|
82 |
| - let ohttp_keys_res = http_client |
83 |
| - .get(format!("{}/ohttp-keys", directory.as_ref())) |
84 |
| - .send() |
85 |
| - .await? |
86 |
| - .bytes() |
87 |
| - .await?; |
88 |
| - Ok(OhttpKeys::decode(ohttp_keys_res.as_ref()).map_err(|_| Error::OhttpDecodeFailed)?) |
| 84 | + let tls_connector = { |
| 85 | + let root_store = futures_rustls::rustls::RootCertStore { |
| 86 | + roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), |
| 87 | + }; |
| 88 | + let config = futures_rustls::rustls::ClientConfig::builder() |
| 89 | + .with_root_certificates(root_store) |
| 90 | + .with_no_client_auth(); |
| 91 | + futures_rustls::TlsConnector::from(Arc::new(config)) |
| 92 | + }; |
| 93 | + let directory_host = directory.host_str().ok_or(Error::BadDirectoryHost)?; |
| 94 | + let domain = futures_rustls::rustls::pki_types::ServerName::try_from(directory_host) |
| 95 | + .map_err(|_| Error::BadDirectoryHost)? |
| 96 | + .to_owned(); |
| 97 | + |
| 98 | + let ws = WebSocket::open(&format!( |
| 99 | + "wss://{}:443", |
| 100 | + ohttp_relay.host_str().ok_or(Error::BadOhttpWsHost)? |
| 101 | + )) |
| 102 | + .map_err(|_| Error::BadOhttpWsHost)?; |
| 103 | + |
| 104 | + let mut tls_stream = tls_connector |
| 105 | + .connect(domain, ws) |
| 106 | + .await |
| 107 | + .map_err(|e| Error::RequestFailed(e.to_string()))?; |
| 108 | + let ohttp_keys_req = format!( |
| 109 | + "GET /ohttp-keys HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", |
| 110 | + directory_host |
| 111 | + ); |
| 112 | + tls_stream |
| 113 | + .write_all(ohttp_keys_req.as_bytes()) |
| 114 | + .await |
| 115 | + .map_err(|e| Error::RequestFailed(e.to_string()))?; |
| 116 | + tls_stream |
| 117 | + .flush() |
| 118 | + .await |
| 119 | + .map_err(|e| Error::RequestFailed(e.to_string()))?; |
| 120 | + let mut response_bytes = Vec::new(); |
| 121 | + tls_stream |
| 122 | + .read_to_end(&mut response_bytes) |
| 123 | + .await |
| 124 | + .map_err(|e| Error::RequestFailed(e.to_string()))?; |
| 125 | + let (_headers, res_body) = separate_headers_and_body(&response_bytes)?; |
| 126 | + payjoin::OhttpKeys::decode(res_body).map_err(|_| Error::OhttpDecodeFailed) |
| 127 | +} |
| 128 | + |
| 129 | +fn separate_headers_and_body(response_bytes: &[u8]) -> Result<(&[u8], &[u8]), Error> { |
| 130 | + let separator = b"\r\n\r\n"; |
| 131 | + |
| 132 | + // Search for the separator |
| 133 | + if let Some(position) = response_bytes |
| 134 | + .windows(separator.len()) |
| 135 | + .position(|window| window == separator) |
| 136 | + { |
| 137 | + // The body starts immediately after the separator |
| 138 | + let body_start_index = position + separator.len(); |
| 139 | + let headers = &response_bytes[..position]; |
| 140 | + let body = &response_bytes[body_start_index..]; |
| 141 | + |
| 142 | + Ok((headers, body)) |
| 143 | + } else { |
| 144 | + Err(Error::RequestFailed( |
| 145 | + "No header-body separator found in the response".to_string(), |
| 146 | + )) |
| 147 | + } |
89 | 148 | }
|
90 | 149 |
|
91 | 150 | #[derive(Debug)]
|
|
0 commit comments