Skip to content

Commit

Permalink
Plugged Client Certificate
Browse files Browse the repository at this point in the history
  • Loading branch information
amigin committed Nov 21, 2023
1 parent 87d99d8 commit 6b5f7ae
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 12 deletions.
5 changes: 3 additions & 2 deletions src/clients_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};

use tokio::sync::RwLock;

use crate::{FlUrlError, HttpClient, UrlBuilder};
use crate::{ClientCertificate, FlUrlError, HttpClient, UrlBuilder};

pub struct ClientsCache {
pub clients: RwLock<HashMap<String, Arc<HttpClient>>>,
Expand All @@ -19,6 +19,7 @@ impl ClientsCache {
&self,
url_builder: &UrlBuilder,
request_timeout: Duration,
client_certificate: Option<ClientCertificate>,
) -> Result<Arc<HttpClient>, FlUrlError> {
let schema_and_domain = url_builder.get_scheme_and_host();
{
Expand All @@ -40,7 +41,7 @@ impl ClientsCache {
.unwrap());
}

let new_one = HttpClient::new(url_builder, request_timeout).await?;
let new_one = HttpClient::new(url_builder, client_certificate, request_timeout).await?;
let new_one = Arc::new(new_one);

write_access.insert(schema_and_domain.to_string(), new_one.clone());
Expand Down
1 change: 1 addition & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub enum FlUrlError {
ConnectionIsDead,
InvalidHttp1HandShake(String),
CanNotEstablishConnection(String),
ClientCertificateError(tokio_rustls::rustls::Error),
#[cfg(feature = "support-unix-socket")]
UnixSocketError(unix_sockets::FlUrlUnixSocketError),
}
Expand Down
6 changes: 4 additions & 2 deletions src/fl_url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,14 @@ impl FlUrl {
let scheme_and_host = self.url.get_scheme_and_host();

let result = if self.do_not_reuse_connection {
let client = HttpClient::new(&self.url, self.execute_timeout).await?;
let client = HttpClient::new(&self.url, self.client_cert, self.execute_timeout).await?;
client
.execute_request(&self.url, method, &self.headers, body, self.execute_timeout)
.await
} else {
let client = CLIENTS_CACHED.get(&self.url, self.execute_timeout).await?;
let client = CLIENTS_CACHED
.get(&self.url, self.execute_timeout, self.client_cert)
.await?;
client
.execute_request(&self.url, method, &self.headers, body, self.execute_timeout)
.await
Expand Down
19 changes: 16 additions & 3 deletions src/http_client/connect_to_tls_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ use tokio::net::TcpStream;

use tokio_rustls::{rustls, TlsConnector};

use crate::FlUrlError;
use crate::{ClientCertificate, FlUrlError};

use super::cert_content::ROOT_CERT_STORE;

pub async fn connect_to_tls_endpoint(
host_port: &str,
domain: &str,
request_timeout: Duration,
client_certificate: Option<ClientCertificate>,
) -> Result<SendRequest<Full<Bytes>>, FlUrlError> {
loop {
let connect = TcpStream::connect(host_port);
Expand All @@ -32,8 +33,20 @@ pub async fn connect_to_tls_endpoint(
Ok(tcp_stream) => {
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(ROOT_CERT_STORE.clone())
.with_no_client_auth();
.with_root_certificates(ROOT_CERT_STORE.clone());

let config = if let Some(client_cert) = client_certificate {
let result =
config.with_client_auth_cert(vec![client_cert.cert], client_cert.pkey);

match result {
Ok(config) => config,
Err(err) => return Err(FlUrlError::ClientCertificateError(err)),
}
} else {
config.with_no_client_auth()
};

let connector = TlsConnector::from(Arc::new(config));

let domain = rustls::ServerName::try_from(domain).unwrap();
Expand Down
20 changes: 15 additions & 5 deletions src/http_client/http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@ use hyper::{client::conn::http1::SendRequest, Method, Request, Uri};
use rust_extensions::StrOrString;
use tokio::sync::Mutex;

use crate::{FlUrlError, FlUrlResponse, UrlBuilder};
use crate::{ClientCertificate, FlUrlError, FlUrlResponse, UrlBuilder};

pub struct HttpClient {
connection: Mutex<Option<SendRequest<Full<Bytes>>>>,
host: String,
}

impl HttpClient {
pub async fn new(src: &UrlBuilder, request_timeout: Duration) -> Result<Self, FlUrlError> {
pub async fn new(
src: &UrlBuilder,
client_certificate: Option<ClientCertificate>,
request_timeout: Duration,
) -> Result<Self, FlUrlError> {
let host_port = src.get_host_port();

let domain = src.get_domain();
Expand All @@ -33,7 +37,13 @@ impl HttpClient {
};

let connection = if is_https {
super::connect_to_tls_endpoint(host_port.as_str(), domain, request_timeout).await?
super::connect_to_tls_endpoint(
host_port.as_str(),
domain,
request_timeout,
client_certificate,
)
.await?
} else {
super::connect_to_http_endpoint(host_port.as_str(), request_timeout).await?
};
Expand Down Expand Up @@ -127,7 +137,7 @@ mod tests {
async fn test_http_request() {
let url_builder = UrlBuilder::new("http://google.com/".into());

let fl_url_client = HttpClient::new(&url_builder, REQUEST_TIMEOUT)
let fl_url_client = HttpClient::new(&url_builder, None, REQUEST_TIMEOUT)
.await
.unwrap();

Expand Down Expand Up @@ -175,7 +185,7 @@ mod tests {
async fn test_https_request() {
let url_builder = UrlBuilder::new("https://trade-demo.yourfin.tech".into());

let fl_url_client = HttpClient::new(&url_builder, REQUEST_TIMEOUT)
let fl_url_client = HttpClient::new(&url_builder, None, REQUEST_TIMEOUT)
.await
.unwrap();

Expand Down

0 comments on commit 6b5f7ae

Please sign in to comment.