Skip to content

Commit

Permalink
Rebase + fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Aug 26, 2024
1 parent 46dec47 commit cdd91f6
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 13 deletions.
1 change: 1 addition & 0 deletions edb/server/edbrust/edbrust-util/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

1 change: 1 addition & 0 deletions edb/server/edbrust/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

10 changes: 6 additions & 4 deletions edb/server/pgrust/src/conn_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,15 @@ pub enum SslVersion {
Tls1_3,
}

impl <'a> TryFrom<Cow<'a, str>> for SslVersion {
impl<'a> TryFrom<Cow<'a, str>> for SslVersion {
type Error = ParseError;
fn try_from(value: Cow<str>) -> Result<SslVersion, Self::Error> {
Ok(match value.as_ref() {
"tls_1" => SslVersion::Tls1,
"tls_1.1" => SslVersion::Tls1_1,
"tls_1.2" => SslVersion::Tls1_2,
"tls_1.3" => SslVersion::Tls1_3,
_ => return Err(ParseError::InvalidTLSVersion(value.to_string()))
_ => return Err(ParseError::InvalidTLSVersion(value.to_string())),
})
}
}
Expand Down Expand Up @@ -685,11 +685,13 @@ pub fn parse_postgres_url(
if ssl_min_protocol_version.is_none() {
ssl_min_protocol_version = env.read("PGSSLMINPROTOCOLVERSION");
}
ssl.min_protocol_version = ssl_min_protocol_version.map(|s| s.try_into()).transpose()?;
ssl.min_protocol_version =
ssl_min_protocol_version.map(|s| s.try_into()).transpose()?;
if ssl_max_protocol_version.is_none() {
ssl_max_protocol_version = env.read("PGSSLMAXPROTOCOLVERSION");
}
ssl.max_protocol_version = ssl_max_protocol_version.map(|s| s.try_into()).transpose()?;
ssl.max_protocol_version =
ssl_max_protocol_version.map(|s| s.try_into()).transpose()?;

// There is no environment variable equivalent to this option
ssl.password = sslpassword.map(|s| s.into_owned());
Expand Down
31 changes: 22 additions & 9 deletions edb/server/pgrust/src/connection/ssl.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
use openssl::{ssl::{SslContextBuilder, SslMethod, SslVerifyMode}, x509::verify::X509VerifyFlags};
use openssl::{
ssl::{SslContextBuilder, SslMethod, SslVerifyMode},
x509::verify::X509VerifyFlags,
};

use crate::conn_string::{SslMode, SslParameters};

/// Given a set of [`SslParameters`], configures an OpenSSL context.
pub fn create_ssl_client_context(mut ssl: SslContextBuilder, ssl_mode: SslMode, parameters: SslParameters) -> Result<SslContextBuilder, Box<dyn std::error::Error>> {
pub fn create_ssl_client_context(
mut ssl: SslContextBuilder,
ssl_mode: SslMode,
parameters: SslParameters,
) -> Result<SslContextBuilder, Box<dyn std::error::Error>> {
let SslParameters {
cert,
key,
Expand All @@ -30,7 +37,8 @@ pub fn create_ssl_client_context(mut ssl: SslContextBuilder, ssl_mode: SslMode,
// Load CRL
if let Some(crl) = &crl {
ssl.set_ca_file(crl)?;
ssl.verify_param_mut().set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?;
ssl.verify_param_mut()
.set_flags(X509VerifyFlags::CRL_CHECK | X509VerifyFlags::CRL_CHECK_ALL)?;
}
}

Expand All @@ -54,7 +62,7 @@ pub fn create_ssl_client_context(mut ssl: SslContextBuilder, ssl_mode: SslMode,

ssl.set_min_proto_version(min_protocol_version.map(|s| s.into()))?;
ssl.set_max_proto_version(max_protocol_version.map(|s| s.into()))?;

// // Configure key log filename
// if let Some(keylog_filename) = &parameters.keylog_filename {
// context_builder.set_keylog_file(keylog_filename)?;
Expand All @@ -74,10 +82,15 @@ mod tests {
let cert_path = Path::new("../../../tests/certs").canonicalize().unwrap();

let ssl = SslContextBuilder::new(SslMethod::tls()).unwrap();
let ssl = create_ssl_client_context(ssl, SslMode::VerifyFull, SslParameters {
cert: Some(cert_path.join("client.cert.pem")),
key: Some(cert_path.join("client.key.pem")),
..Default::default()
}).unwrap();
let ssl = create_ssl_client_context(
ssl,
SslMode::VerifyFull,
SslParameters {
cert: Some(cert_path.join("client.cert.pem")),
key: Some(cert_path.join("client.key.pem")),
..Default::default()
},
)
.unwrap();
}
}

0 comments on commit cdd91f6

Please sign in to comment.