Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Jan 27, 2025
1 parent 32cc9e7 commit a742c5f
Show file tree
Hide file tree
Showing 12 changed files with 394 additions and 137 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ tracing-subscriber = "0"
rstest = "0.22.0"
test-log = { version = "0", features = ["trace"] }
pyo3 = { workspace = true }
gel-stream.workspace = true

[lib]
12 changes: 5 additions & 7 deletions rust/frontend/examples/smoketest.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{cell::RefCell, collections::HashMap, future::Future, rc::Rc};

use gel_auth::CredentialData;
use gel_stream::client::{Connector, Target, TlsParameters};
use openssl::ssl::{Ssl, SslContext, SslMethod};
use pgrust::{connection::{Client, Credentials}, protocol::edgedb::data::{CommandComplete, ParameterStatus, StateDataDescription}};
use tokio::{
Expand Down Expand Up @@ -34,19 +34,17 @@ impl SmokeTest for PostgresSelect {

async fn run(&self, setup: &TestSetup) -> Result<(), Box<dyn std::error::Error>> {
use pgrust::protocol::postgres::data::{DataRow, ErrorResponse, RowDescription};
let mut ssl = SslContext::builder(SslMethod::tls_client())?.build();
let mut ssl = Ssl::new(&ssl)?;
ssl.set_connect_state();

let socket = TcpSocket::new_v4()?.connect(setup.addr).await?;
let target = Target::new_tcp_tls(setup.addr, TlsParameters::default());
let connector = Connector::new(target)?;

let credentials = Credentials {
username: setup.username.clone(),
password: setup.password.clone(),
database: setup.database.clone(),
server_settings: HashMap::new(),
};
let (client, task) = Client::new(credentials, socket, ssl);
let (client, task) = Client::new(credentials, connector);
tokio::task::spawn_local(task);
client.ready().await?;

Expand Down Expand Up @@ -95,7 +93,7 @@ impl SmokeTest for EdgeQLSelect {
}

async fn run(&self, setup: &TestSetup) -> Result<(), Box<dyn std::error::Error>> {
use pgrust::protocol::edgedb::{data::{Message, ClientHandshake, Data, ServerHandshake}, builder, meta};
use pgrust::protocol::edgedb::{data::Data, builder, meta};

let socket = TcpSocket::new_v4()?.connect(setup.addr).await?;
let mut ssl = SslContext::builder(SslMethod::tls_client())?;
Expand Down
2 changes: 2 additions & 0 deletions rust/gel-stream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ __manual_tests = []
derive_more = { version = "1", features = ["full"] }
thiserror = "2"
rustls-pki-types = "1"
futures = "0.3"

tokio = { version = "1", optional = true, features = ["full"] }
rustls = { version = "0.23", optional = true }
openssl = { version = "0.10.55", optional = true }
tokio-openssl = { version = "0.6.5", optional = true }
hickory-resolver = { version = "0.24.2", optional = true }
rustls-tokio-stream = { version = "0.3.0", optional = true }
tokio-rustls = "0.26.0"
rustls-platform-verifier = { version = "0.5.0", optional = true }
webpki = { version = "0.22", optional = true }

Expand Down
4 changes: 2 additions & 2 deletions rust/gel-stream/src/client/connection.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::net::SocketAddr;

use super::stream::UpgradableStream;
use crate::{UpgradableStream, ConnectionError};
use super::target::{MaybeResolvedTarget, ResolvedTarget};
use super::tokio_stream::Resolver;
use super::{ConnectionError, Ssl, Target, TlsInit};
use super::{Ssl, Target, TlsInit};

type Connection = UpgradableStream<super::Stream, Option<super::Ssl>>;

Expand Down
120 changes: 3 additions & 117 deletions rust/gel-stream/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,131 +7,15 @@ pub mod rustls;
#[cfg(feature = "tokio")]
pub mod tokio_stream;

pub mod stream;

mod connection;
pub(crate) mod target;

pub use connection::Connector;
pub use target::{ResolvedTarget, Target, TargetName};

macro_rules! __invalid_state {
($error:literal) => {{
eprintln!(
"Invalid connection state: {}\n{}",
$error,
::std::backtrace::Backtrace::capture()
);
#[allow(deprecated)]
$crate::client::ConnectionError::__InvalidState
}};
}
pub(crate) use __invalid_state as invalid_state;
use rustls_pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName};

#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
/// Invalid state error, suggesting a logic error in code rather than a server or client failure.
/// Use the `invalid_state!` macro instead which will print a backtrace.
#[error("Invalid state")]
#[deprecated = "Use invalid_state!"]
__InvalidState,

/// I/O error encountered during connection operations.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),

/// UTF-8 decoding error.
#[error("UTF8 error: {0}")]
Utf8Error(#[from] std::str::Utf8Error),

/// SSL-related error.
#[error("SSL error: {0}")]
SslError(#[from] SslError),
}

#[derive(Debug, thiserror::Error)]
pub enum SslError {
#[error("SSL is not supported by this client transport")]
SslUnsupportedByClient,

#[cfg(feature = "openssl")]
#[error("OpenSSL error: {0}")]
OpenSslError(#[from] ::openssl::ssl::Error),
#[cfg(feature = "openssl")]
#[error("OpenSSL error: {0}")]
OpenSslErrorStack(#[from] ::openssl::error::ErrorStack),
#[cfg(feature = "openssl")]
#[error("OpenSSL certificate verification error: {0}")]
OpenSslErrorVerify(#[from] ::openssl::x509::X509VerifyResult),

#[cfg(feature = "rustls")]
#[error("Rustls error: {0}")]
RustlsError(#[from] ::rustls::Error),

#[cfg(feature = "rustls")]
#[error("Webpki error: {0}")]
WebpkiError(::webpki::Error),

#[cfg(feature = "rustls")]
#[error("Verifier builder error: {0}")]
VerifierBuilderError(#[from] ::rustls::server::VerifierBuilderError),

#[error("Invalid DNS name: {0}")]
InvalidDnsNameError(#[from] ::rustls_pki_types::InvalidDnsNameError),

#[error("SSL I/O error: {0}")]
SslIoError(#[from] std::io::Error),
}

impl SslError {
/// Returns a common error for any time of crypto backend.
pub fn common_error(&self) -> Option<CommonError> {
match self {
#[cfg(feature = "rustls")]
SslError::RustlsError(::rustls::Error::InvalidCertificate(cert_err)) => {
match cert_err {
::rustls::CertificateError::NotValidForName => {
Some(CommonError::InvalidCertificateForName)
}
::rustls::CertificateError::Revoked => Some(CommonError::CertificateRevoked),
::rustls::CertificateError::Expired => Some(CommonError::CertificateExpired),
::rustls::CertificateError::UnknownIssuer => Some(CommonError::InvalidIssuer),
_ => None,
}
}
#[cfg(feature = "openssl")]
SslError::OpenSslErrorVerify(e) => match e.as_raw() {
openssl_sys::X509_V_ERR_HOSTNAME_MISMATCH => {
Some(CommonError::InvalidCertificateForName)
}
openssl_sys::X509_V_ERR_IP_ADDRESS_MISMATCH => {
Some(CommonError::InvalidCertificateForName)
}
openssl_sys::X509_V_ERR_CERT_REVOKED => Some(CommonError::CertificateRevoked),
openssl_sys::X509_V_ERR_CERT_HAS_EXPIRED => Some(CommonError::CertificateExpired),
openssl_sys::X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT
| openssl_sys::X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY => {
Some(CommonError::InvalidIssuer)
}
_ => None,
},
_ => None,
}
}
}

#[derive(Debug, thiserror::Error, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub enum CommonError {
#[error("The certificate's subject name(s) do not match the name of the host")]
InvalidCertificateForName,
#[error("The certificate has been revoked")]
CertificateRevoked,
#[error("The certificate has expired")]
CertificateExpired,
#[error("The certificate was issued by an untrusted authority")]
InvalidIssuer,
}
use crate::SslError;

// Note that we choose rustls when both openssl and rustls are enabled.

Expand Down Expand Up @@ -223,6 +107,8 @@ mod tests {

use tokio::io::{AsyncReadExt, AsyncWriteExt};

use crate::{CommonError, ConnectionError};

use super::*;

#[cfg(unix)]
Expand Down
66 changes: 57 additions & 9 deletions rust/gel-stream/src/client/rustls.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::client::WebPkiServerVerifier;
use rustls::{
ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme,
ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, ServerConnection, SignatureScheme
};
use rustls_pki_types::{
CertificateDer, CertificateRevocationListDer, DnsName, ServerName, UnixTime,
};
use rustls_platform_verifier::Verifier;
use rustls_tokio_stream::ServerConfigProvider;
use tokio::net::TcpStream;

use super::stream::{Stream, StreamWithUpgrade};
use crate::{RewindStream, Stream, StreamWithUpgrade};
use super::tokio_stream::TokioStream;
use super::{TlsCert, TlsInit, TlsParameters, TlsServerCertVerify};
use std::any::Any;
Expand All @@ -20,23 +22,23 @@ impl<S: Stream + 'static> StreamWithUpgrade for (S, Option<ClientConnection>) {
type Config = ClientConnection;
type Upgrade = rustls_tokio_stream::TlsStream;

async fn secure_upgrade(self) -> Result<Self::Upgrade, super::SslError>
async fn secure_upgrade(self) -> Result<Self::Upgrade, crate::SslError>
where
Self: Sized,
{
let Some(tls) = self.1 else {
return Err(super::SslError::SslUnsupportedByClient);
return Err(crate::SslError::SslUnsupportedByClient);
};

// Note that we only support Tokio TcpStream for rustls.
let stream = &mut Some(self.0) as &mut dyn Any;
let Some(stream) = stream.downcast_mut::<Option<TokioStream>>() else {
return Err(super::SslError::SslUnsupportedByClient);
return Err(crate::SslError::SslUnsupportedByClient);
};

let stream = stream.take().unwrap();
let TokioStream::Tcp(stream) = stream else {
return Err(super::SslError::SslUnsupportedByClient);
return Err(crate::SslError::SslUnsupportedByClient);
};

let mut stream = rustls_tokio_stream::TlsStream::new_client_side(stream, tls, None);
Expand All @@ -47,7 +49,53 @@ impl<S: Stream + 'static> StreamWithUpgrade for (S, Option<ClientConnection>) {
let kind = e.kind();
if let Some(e2) = e.into_inner() {
match e2.downcast::<::rustls::Error>() {
Ok(e) => return Err(super::SslError::RustlsError(*e)),
Ok(e) => return Err(crate::SslError::RustlsError(*e)),
Err(e) => return Err(std::io::Error::new(kind, e).into()),
}
} else {
return Err(std::io::Error::from(kind).into());
}
}

Ok(stream)
}
}

impl<S: Stream + 'static> StreamWithUpgrade for (S, Option<ServerConfigProvider>) {
type Base = S;
type Config = ClientConnection;
type Upgrade = RewindStream<tokio_rustls::TlsStream<RewindStream<TcpStream>>>;

async fn secure_upgrade(self) -> Result<Self::Upgrade, crate::SslError>
where
Self: Sized,
{
let Some(tls) = self.1 else {
return Err(crate::SslError::SslUnsupportedByClient);
};

// Note that we only support Tokio TcpStream for rustls.
let stream = &mut Some(self.0) as &mut dyn Any;
let Some(stream) = stream.downcast_mut::<Option<TokioStream>>() else {
return Err(crate::SslError::SslUnsupportedByClient);
};

let stream = stream.take().unwrap();
let TokioStream::Tcp(stream) = stream else {
return Err(crate::SslError::SslUnsupportedByClient);
};

let mut stream = rustls_tokio_stream::TlsStream::new_server_side_acceptor(stream, tls, None);
let res = stream.handshake().await;
let (tcp, conn) = stream.into_inner().await?;


// Potentially unwrap the error to get the underlying error.
if let Err(e) = res {
let kind = e.kind();
if let Some(e2) = e.into_inner() {
match e2.downcast::<::rustls::Error>() {
Ok(e) => return Err(crate::SslError::RustlsError(*e)),
Err(e) => return Err(std::io::Error::new(kind, e).into()),
}
} else {
Expand All @@ -63,7 +111,7 @@ fn make_verifier(
server_cert_verify: &TlsServerCertVerify,
root_cert: &TlsCert,
crls: Vec<CertificateRevocationListDer<'static>>,
) -> Result<Arc<dyn ServerCertVerifier>, super::SslError> {
) -> Result<Arc<dyn ServerCertVerifier>, crate::SslError> {
if *server_cert_verify == TlsServerCertVerify::Insecure {
return Ok(Arc::new(NullVerifier));
}
Expand Down Expand Up @@ -103,7 +151,7 @@ impl TlsInit for ClientConnection {
fn init(
parameters: &TlsParameters,
name: Option<ServerName>,
) -> Result<Self::Tls, super::SslError> {
) -> Result<Self::Tls, crate::SslError> {
let _ = ::rustls::crypto::ring::default_provider().install_default();

let TlsParameters {
Expand Down
1 change: 1 addition & 0 deletions rust/gel-stream/src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(crate) mod stream;
Loading

0 comments on commit a742c5f

Please sign in to comment.