From e55610be05f922d1017f6c56a8ed112c64cbfd44 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Tue, 26 Nov 2024 11:18:51 -0700 Subject: [PATCH 1/6] . --- Cargo.lock | 1 + rust/pgrust/Cargo.toml | 5 +- rust/pgrust/examples/connect.rs | 2 +- rust/pgrust/examples/proxy.rs | 2 + rust/pgrust/src/auth/mod.rs | 69 ++ rust/pgrust/src/auth/scram.rs | 10 +- rust/pgrust/src/connection/conn.rs | 11 +- rust/pgrust/src/connection/mod.rs | 10 +- rust/pgrust/src/connection/raw_conn.rs | 18 +- rust/pgrust/src/errors/edgedb.rs | 106 +++ rust/pgrust/src/errors/mod.rs | 4 +- .../src/handshake/client_state_machine.rs | 22 +- rust/pgrust/src/handshake/edgedb_server.rs | 378 +++++++++ rust/pgrust/src/handshake/mod.rs | 72 +- rust/pgrust/src/handshake/server_auth.rs | 179 +++++ .../src/handshake/server_state_machine.rs | 482 ++++++------ rust/pgrust/src/protocol/buffer.rs | 11 +- rust/pgrust/src/protocol/datatypes.rs | 180 ++++- rust/pgrust/src/protocol/edgedb.rs | 503 ++++++++++++ rust/pgrust/src/protocol/gen.rs | 59 +- rust/pgrust/src/protocol/message_group.rs | 23 +- rust/pgrust/src/protocol/mod.rs | 37 +- rust/pgrust/src/protocol/postgres.rs | 739 ++++++++++++++++++ rust/pgrust/src/python.rs | 11 +- rust/pgrust/tests/real_postgres.rs | 6 +- 25 files changed, 2539 insertions(+), 401 deletions(-) create mode 100644 rust/pgrust/examples/proxy.rs create mode 100644 rust/pgrust/src/errors/edgedb.rs create mode 100644 rust/pgrust/src/handshake/edgedb_server.rs create mode 100644 rust/pgrust/src/handshake/server_auth.rs create mode 100644 rust/pgrust/src/protocol/edgedb.rs create mode 100644 rust/pgrust/src/protocol/postgres.rs diff --git a/Cargo.lock b/Cargo.lock index 28cc099a643..a04671812a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1590,6 +1590,7 @@ dependencies = [ "tracing-subscriber", "unicode-normalization", "url", + "uuid", ] [[package]] diff --git a/rust/pgrust/Cargo.toml b/rust/pgrust/Cargo.toml index c23310dd1ae..a668bdda775 100644 --- a/rust/pgrust/Cargo.toml +++ b/rust/pgrust/Cargo.toml @@ -15,12 +15,13 @@ optimizer = [] [dependencies] pyo3.workspace = true tokio.workspace = true -tracing.workspace = true futures = "0" scopeguard = "1" itertools = "0" thiserror = "1" +tracing = "0" +tracing-subscriber = "0" strum = { version = "0.26", features = ["derive"] } consume_on_drop = "0" smart-default = "0" @@ -43,6 +44,7 @@ serde-pickle = "1" percent-encoding = "2" roaring = "0.10.6" constant_time_eq = "0.3" +uuid = "1" [dependencies.derive_more] version = "1.0.0-beta.6" @@ -62,7 +64,6 @@ hex-literal = "0.4" tempfile = "3" socket2 = "0.5.7" libc = "0.2.158" -tracing-subscriber = "0" [dev-dependencies.tokio] version = "1" diff --git a/rust/pgrust/examples/connect.rs b/rust/pgrust/examples/connect.rs index e7e8c028999..bb26dafddcc 100644 --- a/rust/pgrust/examples/connect.rs +++ b/rust/pgrust/examples/connect.rs @@ -3,7 +3,7 @@ use clap_derive::Parser; use openssl::ssl::{Ssl, SslContext, SslMethod}; use pgrust::{ connection::{dsn::parse_postgres_dsn_env, Client, Credentials, ResolvedTarget}, - protocol::{DataRow, ErrorResponse, RowDescription}, + protocol::postgres::data::{DataRow, ErrorResponse, RowDescription}, }; use std::net::SocketAddr; use tokio::task::LocalSet; diff --git a/rust/pgrust/examples/proxy.rs b/rust/pgrust/examples/proxy.rs new file mode 100644 index 00000000000..7f755fb76d4 --- /dev/null +++ b/rust/pgrust/examples/proxy.rs @@ -0,0 +1,2 @@ +#[tokio::main] +async fn main() {} diff --git a/rust/pgrust/src/auth/mod.rs b/rust/pgrust/src/auth/mod.rs index 7a90bb0f2eb..2a08588e803 100644 --- a/rust/pgrust/src/auth/mod.rs +++ b/rust/pgrust/src/auth/mod.rs @@ -4,8 +4,77 @@ mod stringprep; mod stringprep_table; pub use md5::{md5_password, StoredHash}; +use rand::Rng; pub use scram::{ generate_salted_password, ClientEnvironment, ClientTransaction, SCRAMError, ServerEnvironment, ServerTransaction, Sha256Out, StoredKey, }; pub use stringprep::{sasl_normalize_password, sasl_normalize_password_bytes}; + +/// Specifies the type of authentication or indicates the authentication method used for a connection. +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] +pub enum AuthType { + /// Denies a login or indicates that a connection was denied. + /// + /// When used with the server, this will cause it to emulate the given + /// authentication type, but unconditionally return a failure. + /// + /// This is used for testing purposes, and to emulate timing when a user + /// does not exist. + #[default] + Deny, + /// Trusts a login without requiring authentication, or indicates + /// that a connection required no authentication. + /// + /// When used with the server side of the handshake, this will cause it to + /// emulate the given authentication type, but unconditionally succeed. + /// Not compatible with SCRAM-SHA-256 as that protocol requires server and client + /// to cryptographically agree on a password. + Trust, + /// Plain text authentication, or indicates that plain text authentication was required. + Plain, + /// MD5 password authentication, or indicates that MD5 password authentication was required. + Md5, + /// SCRAM-SHA-256 authentication, or indicates that SCRAM-SHA-256 authentication was required. + ScramSha256, +} + +#[derive(Debug, Clone)] +pub enum CredentialData { + /// A credential that always succeeds, regardless of input password. Due to + /// the design of SCRAM-SHA-256, this cannot be used with that auth type. + Trust, + /// A credential that always fails, regardless of the input password. + Deny, + /// A plain-text password. + Plain(String), + /// A stored MD5 hash + salt. + Md5(StoredHash), + /// A stored SCRAM-SHA-256 key. + Scram(StoredKey), +} + +impl CredentialData { + pub fn new(ty: AuthType, username: String, password: String) -> Self { + match ty { + AuthType::Deny => Self::Deny, + AuthType::Trust => Self::Trust, + AuthType::Plain => Self::Plain(password), + AuthType::Md5 => Self::Md5(StoredHash::generate(password.as_bytes(), &username)), + AuthType::ScramSha256 => { + let salt: [u8; 32] = rand::thread_rng().gen(); + Self::Scram(StoredKey::generate(password.as_bytes(), &salt, 4096)) + } + } + } + + pub fn auth_type(&self) -> AuthType { + match self { + CredentialData::Trust => AuthType::Trust, + CredentialData::Deny => AuthType::Deny, + CredentialData::Plain(..) => AuthType::Plain, + CredentialData::Md5(..) => AuthType::Md5, + CredentialData::Scram(..) => AuthType::ScramSha256, + } + } +} diff --git a/rust/pgrust/src/auth/scram.rs b/rust/pgrust/src/auth/scram.rs index 63a0533584d..668c74cbe20 100644 --- a/rust/pgrust/src/auth/scram.rs +++ b/rust/pgrust/src/auth/scram.rs @@ -113,7 +113,7 @@ impl ServerTransaction { &mut self, message: &[u8], env: &impl ServerEnvironment, - ) -> Result>, SCRAMError> { + ) -> Result, SCRAMError> { match &self.state { ServerState::Success => Err(SCRAMError::ProtocolError), ServerState::Initial => { @@ -134,7 +134,7 @@ impl ServerTransaction { }; self.state = ServerState::SentChallenge(message.to_owned_bare(), response.to_owned()); - Ok(Some(response.encode().into_bytes())) + Ok(response.encode().into_bytes()) } ServerState::SentChallenge(first_message, first_response) => { let message = ClientFinalMessage::decode(message)?; @@ -174,7 +174,7 @@ impl ServerTransaction { self.state = ServerState::Success; let verifier = BASE64_STANDARD.encode(server_signature).into(); - Ok(Some(ServerFinalResponse { verifier }.encode().into_bytes())) + Ok(ServerFinalResponse { verifier }.encode().into_bytes()) } } } @@ -1015,14 +1015,14 @@ mod tests { String::from_utf8(message.clone()).unwrap(), "n,,n=username,r=<<>>" ); - let message = server.process_message(&message, &env).unwrap().unwrap(); + let message = server.process_message(&message, &env).unwrap(); assert_eq!( String::from_utf8(message.clone()).unwrap(), "r=<<>><<>>,s=aGVsbG8=,i=4096" ); let message = client.process_message(&message, &env).unwrap().unwrap(); assert_eq!(String::from_utf8(message.clone()).unwrap(), "c=biws,r=<<>><<>>,p=621h6u6V3axb7mNYHNgTspTZ3SqILcxuJOsFu5wMjV8="); - let message = server.process_message(&message, &env).unwrap().unwrap(); + let message = server.process_message(&message, &env).unwrap(); assert_eq!( String::from_utf8(message.clone()).unwrap(), "v=moj4kNnZKB3wjXZeQsKYI9luTTakwgH8r0NdGOjugRY=" diff --git a/rust/pgrust/src/connection/conn.rs b/rust/pgrust/src/connection/conn.rs index 6784aefd819..d26f1f69530 100644 --- a/rust/pgrust/src/connection/conn.rs +++ b/rust/pgrust/src/connection/conn.rs @@ -8,8 +8,15 @@ use crate::{ connection::ConnectionError, handshake::ConnectionSslRequirement, protocol::{ - builder, match_message, meta, CommandComplete, DataRow, ErrorResponse, Message, - ReadyForQuery, RowDescription, StructBuffer, + match_message, + postgres::{ + builder, + data::{ + CommandComplete, DataRow, ErrorResponse, Message, ReadyForQuery, RowDescription, + }, + meta, + }, + StructBuffer, }, }; use futures::FutureExt; diff --git a/rust/pgrust/src/connection/mod.rs b/rust/pgrust/src/connection/mod.rs index 2bcf5f7ef7e..aaac05e07c8 100644 --- a/rust/pgrust/src/connection/mod.rs +++ b/rust/pgrust/src/connection/mod.rs @@ -1,6 +1,10 @@ use std::collections::HashMap; -use crate::{auth, errors::PgServerError, protocol::ParseError}; +use crate::{ + auth, + errors::{edgedb::EdbError, PgServerError}, + protocol::ParseError, +}; mod conn; pub mod dsn; @@ -38,6 +42,10 @@ pub enum ConnectionError { #[error("Server error: {0}")] ServerError(#[from] PgServerError), + /// Error returned by the server. + #[error("Server error: {0}")] + EdbServerError(#[from] EdbError), + /// The server sent something we didn't expect #[error("Unexpected server response: {0}")] UnexpectedResponse(String), diff --git a/rust/pgrust/src/connection/raw_conn.rs b/rust/pgrust/src/connection/raw_conn.rs index e3a72fc29b0..a19c95216e5 100644 --- a/rust/pgrust/src/connection/raw_conn.rs +++ b/rust/pgrust/src/connection/raw_conn.rs @@ -7,9 +7,13 @@ use crate::handshake::{ ConnectionDrive, ConnectionState, ConnectionStateSend, ConnectionStateType, ConnectionStateUpdate, }, - AuthType, ConnectionSslRequirement, + ConnectionSslRequirement, +}; +use crate::protocol::{postgres::data::SSLResponse, postgres::meta, StructBuffer}; +use crate::{ + auth::AuthType, + protocol::postgres::{FrontendBuilder, InitialBuilder}, }; -use crate::protocol::{meta, SSLResponse, StructBuffer}; use std::collections::HashMap; use std::pin::Pin; use std::task::{Context, Poll}; @@ -32,17 +36,11 @@ pub struct ConnectionDriver { } impl ConnectionStateSend for ConnectionDriver { - fn send_initial( - &mut self, - message: crate::protocol::definition::InitialBuilder, - ) -> Result<(), std::io::Error> { + fn send_initial(&mut self, message: InitialBuilder) -> Result<(), std::io::Error> { self.send_buffer.extend(message.to_vec()); Ok(()) } - fn send( - &mut self, - message: crate::protocol::definition::FrontendBuilder, - ) -> Result<(), std::io::Error> { + fn send(&mut self, message: FrontendBuilder) -> Result<(), std::io::Error> { self.send_buffer.extend(message.to_vec()); Ok(()) } diff --git a/rust/pgrust/src/errors/edgedb.rs b/rust/pgrust/src/errors/edgedb.rs new file mode 100644 index 00000000000..494eebb0ba3 --- /dev/null +++ b/rust/pgrust/src/errors/edgedb.rs @@ -0,0 +1,106 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::Display)] +#[repr(i32)] +pub enum EdbError { + InternalServerError = 0x_01_00_00_00, + UnsupportedFeatureError = 0x_02_00_00_00, + ProtocolError = 0x_03_00_00_00, + BinaryProtocolError = 0x_03_01_00_00, + UnsupportedProtocolVersionError = 0x_03_01_00_01, + TypeSpecNotFoundError = 0x_03_01_00_02, + UnexpectedMessageError = 0x_03_01_00_03, + InputDataError = 0x_03_02_00_00, + ParameterTypeMismatchError = 0x_03_02_01_00, + StateMismatchError = 0x_03_02_02_00, + ResultCardinalityMismatchError = 0x_03_03_00_00, + CapabilityError = 0x_03_04_00_00, + UnsupportedCapabilityError = 0x_03_04_01_00, + DisabledCapabilityError = 0x_03_04_02_00, + QueryError = 0x_04_00_00_00, + InvalidSyntaxError = 0x_04_01_00_00, + EdgeQLSyntaxError = 0x_04_01_01_00, + SchemaSyntaxError = 0x_04_01_02_00, + GraphQLSyntaxError = 0x_04_01_03_00, + InvalidTypeError = 0x_04_02_00_00, + InvalidTargetError = 0x_04_02_01_00, + InvalidLinkTargetError = 0x_04_02_01_01, + InvalidPropertyTargetError = 0x_04_02_01_02, + InvalidReferenceError = 0x_04_03_00_00, + UnknownModuleError = 0x_04_03_00_01, + UnknownLinkError = 0x_04_03_00_02, + UnknownPropertyError = 0x_04_03_00_03, + UnknownUserError = 0x_04_03_00_04, + UnknownDatabaseError = 0x_04_03_00_05, + UnknownParameterError = 0x_04_03_00_06, + SchemaError = 0x_04_04_00_00, + SchemaDefinitionError = 0x_04_05_00_00, + InvalidDefinitionError = 0x_04_05_01_00, + InvalidModuleDefinitionError = 0x_04_05_01_01, + InvalidLinkDefinitionError = 0x_04_05_01_02, + InvalidPropertyDefinitionError = 0x_04_05_01_03, + InvalidUserDefinitionError = 0x_04_05_01_04, + InvalidDatabaseDefinitionError = 0x_04_05_01_05, + InvalidOperatorDefinitionError = 0x_04_05_01_06, + InvalidAliasDefinitionError = 0x_04_05_01_07, + InvalidFunctionDefinitionError = 0x_04_05_01_08, + InvalidConstraintDefinitionError = 0x_04_05_01_09, + InvalidCastDefinitionError = 0x_04_05_01_0A, + DuplicateDefinitionError = 0x_04_05_02_00, + DuplicateModuleDefinitionError = 0x_04_05_02_01, + DuplicateLinkDefinitionError = 0x_04_05_02_02, + DuplicatePropertyDefinitionError = 0x_04_05_02_03, + DuplicateUserDefinitionError = 0x_04_05_02_04, + DuplicateDatabaseDefinitionError = 0x_04_05_02_05, + DuplicateOperatorDefinitionError = 0x_04_05_02_06, + DuplicateViewDefinitionError = 0x_04_05_02_07, + DuplicateFunctionDefinitionError = 0x_04_05_02_08, + DuplicateConstraintDefinitionError = 0x_04_05_02_09, + DuplicateCastDefinitionError = 0x_04_05_02_0A, + DuplicateMigrationError = 0x_04_05_02_0B, + SessionTimeoutError = 0x_04_06_00_00, + IdleSessionTimeoutError = 0x_04_06_01_00, + QueryTimeoutError = 0x_04_06_02_00, + TransactionTimeoutError = 0x_04_06_0A_00, + IdleTransactionTimeoutError = 0x_04_06_0A_01, + ExecutionError = 0x_05_00_00_00, + InvalidValueError = 0x_05_01_00_00, + DivisionByZeroError = 0x_05_01_00_01, + NumericOutOfRangeError = 0x_05_01_00_02, + AccessPolicyError = 0x_05_01_00_03, + QueryAssertionError = 0x_05_01_00_04, + IntegrityError = 0x_05_02_00_00, + ConstraintViolationError = 0x_05_02_00_01, + CardinalityViolationError = 0x_05_02_00_02, + MissingRequiredError = 0x_05_02_00_03, + TransactionError = 0x_05_03_00_00, + TransactionConflictError = 0x_05_03_01_00, + TransactionSerializationError = 0x_05_03_01_01, + TransactionDeadlockError = 0x_05_03_01_02, + WatchError = 0x_05_04_00_00, + ConfigurationError = 0x_06_00_00_00, + AccessError = 0x_07_00_00_00, + AuthenticationError = 0x_07_01_00_00, + AvailabilityError = 0x_08_00_00_00, + BackendUnavailableError = 0x_08_00_00_01, + ServerOfflineError = 0x_08_00_00_02, + UnknownTenantError = 0x_08_00_00_03, + ServerBlockedError = 0x_08_00_00_04, + BackendError = 0x_09_00_00_00, + UnsupportedBackendFeatureError = 0x_09_00_01_00, + LogMessage = 0x_F0_00_00_00_u32 as i32, + WarningMessage = 0x_F0_01_00_00_u32 as i32, + ClientError = 0x_FF_00_00_00_u32 as i32, + ClientConnectionError = 0x_FF_01_00_00_u32 as i32, + ClientConnectionFailedError = 0x_FF_01_01_00_u32 as i32, + ClientConnectionFailedTemporarilyError = 0x_FF_01_01_01_u32 as i32, + ClientConnectionTimeoutError = 0x_FF_01_02_00_u32 as i32, + ClientConnectionClosedError = 0x_FF_01_03_00_u32 as i32, + InterfaceError = 0x_FF_02_00_00_u32 as i32, + QueryArgumentError = 0x_FF_02_01_00_u32 as i32, + MissingArgumentError = 0x_FF_02_01_01_u32 as i32, + UnknownArgumentError = 0x_FF_02_01_02_u32 as i32, + InvalidArgumentError = 0x_FF_02_01_03_u32 as i32, + NoDataError = 0x_FF_03_00_00_u32 as i32, + InternalClientError = 0x_FF_04_00_00_u32 as i32, +} + +impl std::error::Error for EdbError {} diff --git a/rust/pgrust/src/errors/mod.rs b/rust/pgrust/src/errors/mod.rs index e8871cd5b8c..39e7e538e28 100644 --- a/rust/pgrust/src/errors/mod.rs +++ b/rust/pgrust/src/errors/mod.rs @@ -2,7 +2,9 @@ use core::str; use paste::paste; use std::{collections::HashMap, str::FromStr}; -use crate::protocol::ErrorResponse; +pub mod edgedb; + +use crate::protocol::postgres::data::ErrorResponse; #[macro_export] macro_rules! pg_error_class { diff --git a/rust/pgrust/src/handshake/client_state_machine.rs b/rust/pgrust/src/handshake/client_state_machine.rs index 77e9250a978..32524cac5de 100644 --- a/rust/pgrust/src/handshake/client_state_machine.rs +++ b/rust/pgrust/src/handshake/client_state_machine.rs @@ -1,15 +1,20 @@ -use super::{AuthType, ConnectionSslRequirement}; +use super::ConnectionSslRequirement; use crate::{ - auth::{self, generate_salted_password, ClientEnvironment, ClientTransaction, Sha256Out}, + auth::{ + self, generate_salted_password, AuthType, ClientEnvironment, ClientTransaction, Sha256Out, + }, connection::{invalid_state, ConnectionError, Credentials, SslError}, errors::PgServerError, protocol::{ - builder, - definition::{FrontendBuilder, InitialBuilder}, - match_message, AuthenticationCleartextPassword, AuthenticationMD5Password, - AuthenticationMessage, AuthenticationOk, AuthenticationSASL, AuthenticationSASLContinue, - AuthenticationSASLFinal, BackendKeyData, ErrorResponse, Message, ParameterStatus, - ParseError, ReadyForQuery, SSLResponse, + match_message, + postgres::data::{ + AuthenticationCleartextPassword, AuthenticationMD5Password, AuthenticationMessage, + AuthenticationOk, AuthenticationSASL, AuthenticationSASLContinue, + AuthenticationSASLFinal, BackendKeyData, ErrorResponse, Message, ParameterStatus, + ReadyForQuery, SSLResponse, + }, + postgres::{builder, FrontendBuilder, InitialBuilder}, + ParseError, }, }; use base64::Engine; @@ -114,6 +119,7 @@ pub trait ConnectionStateUpdate: ConnectionStateSend { /// /// The state machine for a Postgres connection. The state machine is driven /// with calls to [`Self::drive`]. +#[derive(Debug)] pub struct ConnectionState(ConnectionStateImpl); impl ConnectionState { diff --git a/rust/pgrust/src/handshake/edgedb_server.rs b/rust/pgrust/src/handshake/edgedb_server.rs new file mode 100644 index 00000000000..2c325129aa9 --- /dev/null +++ b/rust/pgrust/src/handshake/edgedb_server.rs @@ -0,0 +1,378 @@ +use super::server_auth::{ServerAuth, ServerAuthError}; +use crate::{ + auth::{AuthType, CredentialData}, + connection::ConnectionError, + errors::edgedb::EdbError, + handshake::server_auth::{ServerAuthDrive, ServerAuthResponse}, + protocol::{ + edgedb::{data::*, *}, + match_message, ParseError, StructBuffer, + }, +}; +use std::str::Utf8Error; +use tracing::{error, trace, warn}; + +#[derive(Clone, Copy, Debug)] +pub enum ConnectionStateType { + Connecting, + Authenticating, + Synchronizing, + Ready, +} + +#[derive(Debug)] +pub enum ConnectionDrive<'a> { + RawMessage(&'a [u8]), + Message(Result, ParseError>), + AuthInfo(AuthType, CredentialData), + Parameter(String, String), + Ready([u8; 32]), + Fail(EdbError, &'a str), +} + +pub trait ConnectionStateSend { + fn send(&mut self, message: EdgeDBBackendBuilder) -> Result<(), std::io::Error>; + fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error>; + fn params(&mut self) -> Result<(), std::io::Error>; +} + +pub trait ConnectionStateUpdate: ConnectionStateSend { + fn parameter(&mut self, name: &str, value: &str) {} + fn state_changed(&mut self, state: ConnectionStateType) {} + fn server_error(&mut self, error: &EdbError) {} +} + +#[derive(Debug)] +pub enum ConnectionEvent<'a> { + Send(EdgeDBBackendBuilder<'a>), + Auth(String, String), + Params, + Parameter(&'a str, &'a str), + StateChanged(ConnectionStateType), + ServerError(EdbError), +} + +impl ConnectionStateSend for F +where + F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>, +{ + fn send(&mut self, message: EdgeDBBackendBuilder) -> Result<(), std::io::Error> { + self(ConnectionEvent::Send(message)) + } + + fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error> { + self(ConnectionEvent::Auth(user, database)) + } + + fn params(&mut self) -> Result<(), std::io::Error> { + self(ConnectionEvent::Params) + } +} + +impl ConnectionStateUpdate for F +where + F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>, +{ + fn parameter(&mut self, name: &str, value: &str) { + let _ = self(ConnectionEvent::Parameter(name, value)); + } + + fn state_changed(&mut self, state: ConnectionStateType) { + let _ = self(ConnectionEvent::StateChanged(state)); + } + + fn server_error(&mut self, error: &EdbError) { + let _ = self(ConnectionEvent::ServerError(*error)); + } +} + +#[derive(Debug, derive_more::Display, thiserror::Error)] +enum ServerError { + IO(#[from] std::io::Error), + Protocol(#[from] EdbError), + Utf8Error(#[from] Utf8Error), +} + +impl From for ServerError { + fn from(value: ServerAuthError) -> Self { + match value { + ServerAuthError::InvalidAuthorizationSpecification => { + ServerError::Protocol(EdbError::AuthenticationError) + } + ServerAuthError::InvalidPassword => { + ServerError::Protocol(EdbError::AuthenticationError) + } + ServerAuthError::InvalidSaslMessage(_) => { + ServerError::Protocol(EdbError::ProtocolError) + } + ServerAuthError::UnsupportedAuthType => { + ServerError::Protocol(EdbError::UnsupportedFeatureError) + } + ServerAuthError::InvalidMessageType => ServerError::Protocol(EdbError::ProtocolError), + } + } +} + +const PROTOCOL_ERROR: ServerError = ServerError::Protocol(EdbError::ProtocolError); +const AUTH_ERROR: ServerError = ServerError::Protocol(EdbError::AuthenticationError); +const PROTOCOL_VERSION_ERROR: ServerError = + ServerError::Protocol(EdbError::UnsupportedProtocolVersionError); + +#[derive(Debug)] +enum ServerStateImpl { + Initial, + AuthInfo(String), + Authenticating(ServerAuth), + Synchronizing, + Ready, + Error, +} + +pub struct ServerState { + state: ServerStateImpl, + buffer: StructBuffer, +} + +impl ServerState { + pub fn new() -> Self { + Self { + state: ServerStateImpl::Initial, + buffer: Default::default(), + } + } + + pub fn is_ready(&self) -> bool { + matches!(self.state, ServerStateImpl::Ready) + } + + pub fn is_error(&self) -> bool { + matches!(self.state, ServerStateImpl::Error) + } + + pub fn is_done(&self) -> bool { + self.is_ready() || self.is_error() + } + + pub fn drive( + &mut self, + drive: ConnectionDrive, + update: &mut impl ConnectionStateUpdate, + ) -> Result<(), ConnectionError> { + trace!("SERVER DRIVE: {:?} {:?}", self.state, drive); + let res = match drive { + ConnectionDrive::RawMessage(raw) => self.buffer.push_fallible(raw, |message| { + trace!("Parsed message: {message:?}"); + self.state + .drive_inner(ConnectionDrive::Message(message), update) + }), + drive => self.state.drive_inner(drive, update), + }; + + match res { + Ok(_) => Ok(()), + Err(ServerError::IO(e)) => Err(e.into()), + Err(ServerError::Utf8Error(e)) => Err(e.into()), + Err(ServerError::Protocol(code)) => { + self.state = ServerStateImpl::Error; + send_error(update, code, "Connection error")?; + Err(code.into()) + } + } + } +} + +impl ServerStateImpl { + fn drive_inner( + &mut self, + drive: ConnectionDrive, + update: &mut impl ConnectionStateUpdate, + ) -> Result<(), ServerError> { + use ServerStateImpl::*; + + match (&mut *self, drive) { + (Initial, ConnectionDrive::Message(message)) => { + match_message!(message, Message { + (ClientHandshake as handshake) => { + trace!("ClientHandshake: {handshake:?}"); + + // The handshake should generate an event rather than hardcoding the min/max protocol versions. + + // No extensions are supported + if !handshake.extensions().is_empty() { + update.send(EdgeDBBackendBuilder::ServerHandshake(builder::ServerHandshake { major_ver: 2, minor_ver: 0, extensions: &[] }))?; + return Ok(()); + } + + // We support 1.x and 2.0 + let major_ver = handshake.major_ver(); + let minor_ver = handshake.minor_ver(); + match (major_ver, minor_ver) { + (..=0, _) => { + update.send(EdgeDBBackendBuilder::ServerHandshake(builder::ServerHandshake { major_ver: 1, minor_ver: 0, extensions: &[] }))?; + return Ok(()); + } + (1, 1..) => { + // 1.(1+) never existed + return Err(PROTOCOL_VERSION_ERROR); + } + (2, 1..) | (3.., _) => { + update.send(EdgeDBBackendBuilder::ServerHandshake(builder::ServerHandshake { major_ver: 2, minor_ver: 0, extensions: &[] }))?; + return Ok(()); + } + _ => {} + } + + let mut user = String::new(); + let mut database = String::new(); + let mut branch = String::new(); + for param in handshake.params() { + match param.name().to_str()? { + "user" => user = param.value().to_owned()?, + "database" => database = param.value().to_owned()?, + "branch" => branch = param.value().to_owned()?, + _ => {} + } + update.parameter(param.name().to_str()?, param.value().to_str()?); + } + if user.is_empty() { + return Err(AUTH_ERROR.into()); + } + if database.is_empty() { + database = user.clone(); + } + *self = AuthInfo(user.clone()); + update.auth(user, database)?; + }, + unknown => { + log_unknown_message(unknown, "Initial")?; + } + }); + } + (AuthInfo(_), ConnectionDrive::AuthInfo(auth_type, credential_data)) => { + let mut auth = ServerAuth::new(String::new(), auth_type, credential_data); + match auth.drive(ServerAuthDrive::Initial) { + ServerAuthResponse::Initial(AuthType::ScramSha256, _) => { + update.send(EdgeDBBackendBuilder::AuthenticationRequiredSASLMessage( + builder::AuthenticationRequiredSASLMessage { + methods: &["SCRAM-SHA-256"], + }, + ))?; + } + ServerAuthResponse::Complete(..) => { + update.send(EdgeDBBackendBuilder::AuthenticationOk(Default::default()))?; + *self = Synchronizing; + update.params()?; + return Ok(()); + } + ServerAuthResponse::Error(e) => return Err(e.into()), + _ => return Err(PROTOCOL_ERROR), + } + *self = Authenticating(auth); + } + (Authenticating(auth), ConnectionDrive::Message(message)) => { + match_message!(message, Message { + (AuthenticationSASLInitialResponse as sasl) if auth.is_initial_message() => { + match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.sasl_data().as_ref())) { + ServerAuthResponse::Continue(final_message) => { + update.send(EdgeDBBackendBuilder::AuthenticationSASLContinue(builder::AuthenticationSASLContinue { + sasl_data: &final_message, + }))?; + } + ServerAuthResponse::Error(e) => return Err(e.into()), + _ => return Err(PROTOCOL_ERROR), + } + }, + (AuthenticationSASLResponse as sasl) if !auth.is_initial_message() => { + match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.sasl_data().as_ref())) { + ServerAuthResponse::Complete(data) => { + update.send(EdgeDBBackendBuilder::AuthenticationSASLFinal(builder::AuthenticationSASLFinal { + sasl_data: &data, + }))?; + update.send(EdgeDBBackendBuilder::AuthenticationOk(Default::default()))?; + *self = Synchronizing; + update.params()?; + } + ServerAuthResponse::Error(e) => return Err(e.into()), + _ => return Err(PROTOCOL_ERROR), + } + }, + unknown => { + log_unknown_message(unknown, "Authenticating")?; + } + }); + } + (Synchronizing, ConnectionDrive::Parameter(name, value)) => { + update.send(EdgeDBBackendBuilder::ParameterStatus( + builder::ParameterStatus { + name: name.as_bytes(), + value: value.as_bytes(), + }, + ))?; + } + (Synchronizing, ConnectionDrive::Ready(key_data)) => { + update.send(EdgeDBBackendBuilder::ServerKeyData( + builder::ServerKeyData { data: key_data }, + ))?; + update.send(EdgeDBBackendBuilder::ReadyForCommand( + builder::ReadyForCommand { + annotations: &[], + transaction_state: 0x49, + }, + ))?; + *self = Ready; + } + (_, ConnectionDrive::Fail(error, _)) => { + return Err(ServerError::Protocol(error)); + } + _ => { + error!("Unexpected drive in state {:?}", self); + return Err(PROTOCOL_ERROR); + } + } + + Ok(()) + } +} + +fn log_unknown_message( + message: Result, + state: &str, +) -> Result<(), ServerError> { + match message { + Ok(message) => { + warn!( + "Unexpected message {:?} (length {}) received in {} state", + message.mtype(), + message.mlen(), + state + ); + Ok(()) + } + Err(e) => { + error!("Corrupted message received in {} state {:?}", state, e); + Err(PROTOCOL_ERROR) + } + } +} + +fn send_error( + update: &mut impl ConnectionStateUpdate, + code: EdbError, + message: &str, +) -> std::io::Result<()> { + update.server_error(&code); + update.send(EdgeDBBackendBuilder::ErrorResponse( + builder::ErrorResponse { + severity: 0x78, + error_code: code as i32, + message, + attributes: &[], + }, + )) +} + +enum ErrorSeverity { + ERROR = 0x78, + FATAL = 0xc8, + PANIC = 0xff, +} diff --git a/rust/pgrust/src/handshake/mod.rs b/rust/pgrust/src/handshake/mod.rs index 23f0083c127..70b3bea9185 100644 --- a/rust/pgrust/src/handshake/mod.rs +++ b/rust/pgrust/src/handshake/mod.rs @@ -9,35 +9,9 @@ pub enum ConnectionSslRequirement { Required, } -/// Specifies the type of authentication or indicates the authentication method used for a connection. -#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] -pub enum AuthType { - /// Denies a login or indicates that a connection was denied. - /// - /// When used with the server, this will cause it to emulate the given - /// authentication type, but unconditionally return a failure. - /// - /// This is used for testing purposes, and to emulate timing when a user - /// does not exist. - #[default] - Deny, - /// Trusts a login without requiring authentication, or indicates - /// that a connection required no authentication. - /// - /// When used with the server side of the handshake, this will cause it to - /// emulate the given authentication type, but unconditionally succeed. - /// Not compatible with SCRAM-SHA-256 as that protocol requires server and client - /// to cryptographically agree on a password. - Trust, - /// Plain text authentication, or indicates that plain text authentication was required. - Plain, - /// MD5 password authentication, or indicates that MD5 password authentication was required. - Md5, - /// SCRAM-SHA-256 authentication, or indicates that SCRAM-SHA-256 authentication was required. - ScramSha256, -} - mod client_state_machine; +pub mod edgedb_server; +mod server_auth; mod server_state_machine; pub mod client { @@ -56,10 +30,10 @@ mod tests { ConnectionSslRequirement, }; use crate::{ + auth::{AuthType, CredentialData}, connection::Credentials, errors::{PgError, PgServerError}, - handshake::{server::CredentialData, AuthType}, - protocol::{InitialMessage, Message}, + protocol::postgres::{data::*, *}, }; use rstest::rstest; use std::collections::VecDeque; @@ -91,18 +65,12 @@ mod tests { } impl client::ConnectionStateSend for ConnectionPipe { - fn send( - &mut self, - message: crate::protocol::definition::FrontendBuilder, - ) -> Result<(), std::io::Error> { + fn send(&mut self, message: FrontendBuilder) -> Result<(), std::io::Error> { eprintln!("Client -> Server {message:?}"); self.smsg.push_back((false, message.to_vec())); Ok(()) } - fn send_initial( - &mut self, - message: crate::protocol::definition::InitialBuilder, - ) -> Result<(), std::io::Error> { + fn send_initial(&mut self, message: InitialBuilder) -> Result<(), std::io::Error> { eprintln!("Client -> Server {message:?}"); self.smsg.push_back((true, message.to_vec())); Ok(()) @@ -131,18 +99,12 @@ mod tests { self.sparams = true; Ok(()) } - fn send( - &mut self, - message: crate::protocol::definition::BackendBuilder, - ) -> Result<(), std::io::Error> { + fn send(&mut self, message: BackendBuilder) -> Result<(), std::io::Error> { eprintln!("Server -> Client {message:?}"); self.cmsg.push_back((false, message.to_vec())); Ok(()) } - fn send_ssl( - &mut self, - message: crate::protocol::builder::SSLResponse, - ) -> Result<(), std::io::Error> { + fn send_ssl(&mut self, message: builder::SSLResponse) -> Result<(), std::io::Error> { self.cmsg.push_back((true, message.to_vec())); Ok(()) } @@ -153,7 +115,7 @@ mod tests { /// We test the full matrix of server and client combinations. #[rstest] - #[test] + #[test_log::test] fn test_both( #[values( AuthType::Deny, @@ -182,8 +144,7 @@ mod tests { }, ConnectionSslRequirement::Disable, ); - let mut server = - server::ServerState::new(ConnectionSslRequirement::Disable, 0x1234, 0x4321); + let mut server = server::ServerState::new(ConnectionSslRequirement::Disable); // We test all variations here, but not all combinations will result in // valid auth, even with a correct password. @@ -231,7 +192,7 @@ mod tests { }; let data = CredentialData::new(credential_type, user.clone(), password); server_error |= server - .drive(ConnectionDrive::AuthInfo(user, auth_type, data), &mut pipe) + .drive(ConnectionDrive::AuthInfo(auth_type, data), &mut pipe) .is_err(); } if pipe.sparams { @@ -247,7 +208,9 @@ mod tests { &mut pipe, ) .is_err(); - server_error |= server.drive(ConnectionDrive::Ready, &mut pipe).is_err(); + server_error |= server + .drive(ConnectionDrive::Ready(1234, 4567), &mut pipe) + .is_err(); } while let Some((initial, msg)) = pipe.smsg.pop_front() { if initial { @@ -281,7 +244,12 @@ mod tests { } if expect_success { - assert!(client.is_ready() && server.is_ready()) + assert!( + client.is_ready() && server.is_ready(), + "client={:?} server={:?}", + client, + server + ); } else { assert!(client_error && server_error); assert!(pipe.cerror.is_some() && pipe.serror.is_some()); diff --git a/rust/pgrust/src/handshake/server_auth.rs b/rust/pgrust/src/handshake/server_auth.rs new file mode 100644 index 00000000000..2d3da6746da --- /dev/null +++ b/rust/pgrust/src/handshake/server_auth.rs @@ -0,0 +1,179 @@ +use crate::auth::{AuthType, CredentialData, SCRAMError, ServerTransaction, StoredHash, StoredKey}; +use tracing::error; + +#[derive(Debug)] +pub enum ServerAuthResponse { + Initial(AuthType, Vec), + Continue(Vec), + Complete(Vec), + Error(ServerAuthError), +} + +#[derive(Debug, thiserror::Error)] +pub enum ServerAuthError { + #[error("Invalid authorization specification")] + InvalidAuthorizationSpecification, + #[error("Invalid password")] + InvalidPassword, + #[error("Invalid SASL message ({0})")] + InvalidSaslMessage(SCRAMError), + #[error("Unsupported authentication type")] + UnsupportedAuthType, + #[error("Invalid message type")] + InvalidMessageType, +} + +#[derive(Debug)] +enum ServerAuthState { + Initial, + Password(CredentialData), + MD5([u8; 4], CredentialData), + SASL(ServerTransaction, StoredKey), +} + +#[derive(Debug)] +pub enum ServerAuthDrive<'a> { + Initial, + Message(AuthType, &'a [u8]), +} + +#[derive(Debug)] +pub struct ServerAuth { + state: ServerAuthState, + username: String, + auth_type: AuthType, + credential_data: CredentialData, +} + +impl ServerAuth { + pub fn new(username: String, auth_type: AuthType, credential_data: CredentialData) -> Self { + Self { + state: ServerAuthState::Initial, + username, + auth_type, + credential_data, + } + } + + pub fn is_initial_message(&self) -> bool { + match &self.state { + ServerAuthState::Initial => false, + ServerAuthState::SASL(tx, _) => tx.initial(), + _ => true, + } + } + + pub fn auth_type(&self) -> AuthType { + self.auth_type + } + + pub fn drive(&mut self, drive: ServerAuthDrive) -> ServerAuthResponse { + match (&mut self.state, drive) { + (ServerAuthState::Initial, ServerAuthDrive::Initial) => self.handle_initial(), + (ServerAuthState::Password(data), ServerAuthDrive::Message(AuthType::Plain, input)) => { + let client_password = input; + let success = match data { + CredentialData::Deny => false, + CredentialData::Trust => true, + CredentialData::Plain(password) => client_password == password.as_bytes(), + CredentialData::Md5(md5) => { + let md5_1 = StoredHash::generate(client_password, &self.username); + md5_1 == *md5 + } + CredentialData::Scram(scram) => { + let key = + StoredKey::generate(client_password, &scram.salt, scram.iterations); + key.stored_key == scram.stored_key + } + }; + if success { + ServerAuthResponse::Complete(Vec::new()) + } else { + ServerAuthResponse::Error(ServerAuthError::InvalidPassword) + } + } + (ServerAuthState::MD5(salt, data), ServerAuthDrive::Message(AuthType::Md5, input)) => { + let success = match data { + CredentialData::Deny => false, + CredentialData::Trust => true, + CredentialData::Plain(password) => { + let server_md5 = StoredHash::generate(password.as_bytes(), &self.username); + server_md5.matches(input, *salt) + } + CredentialData::Md5(server_md5) => server_md5.matches(input, *salt), + CredentialData::Scram(_) => { + // Unreachable + false + } + }; + + if success { + ServerAuthResponse::Complete(Vec::new()) + } else { + ServerAuthResponse::Error(ServerAuthError::InvalidPassword) + } + } + ( + ServerAuthState::SASL(tx, data), + ServerAuthDrive::Message(AuthType::ScramSha256, input), + ) => { + let initial = tx.initial(); + match tx.process_message(input, data) { + Ok(final_message) => { + if initial { + ServerAuthResponse::Continue(final_message) + } else { + ServerAuthResponse::Complete(final_message) + } + } + Err(e) => ServerAuthResponse::Error(ServerAuthError::InvalidSaslMessage(e)), + } + } + (_, drive) => { + error!("Received invalid drive {drive:?} in state {:?}", self.state); + ServerAuthResponse::Error(ServerAuthError::InvalidMessageType) + } + } + } + + fn handle_initial(&mut self) -> ServerAuthResponse { + match self.auth_type { + AuthType::Deny => { + ServerAuthResponse::Error(ServerAuthError::InvalidAuthorizationSpecification) + } + AuthType::Trust => ServerAuthResponse::Complete(Vec::new()), + AuthType::Plain => { + self.state = ServerAuthState::Password(self.credential_data.clone()); + ServerAuthResponse::Initial(AuthType::Plain, Vec::new()) + } + AuthType::Md5 => { + let salt: [u8; 4] = rand::random(); + match self.credential_data { + CredentialData::Scram(..) => { + ServerAuthResponse::Error(ServerAuthError::UnsupportedAuthType) + } + _ => { + self.state = ServerAuthState::MD5(salt, self.credential_data.clone()); + ServerAuthResponse::Initial(AuthType::Md5, salt.into()) + } + } + } + AuthType::ScramSha256 => { + let salt: [u8; 32] = rand::random(); + let scram = match &self.credential_data { + CredentialData::Scram(scram) => scram.clone(), + CredentialData::Plain(password) => { + StoredKey::generate(password.as_bytes(), &salt, 4096) + } + CredentialData::Deny => StoredKey::generate(b"", &salt, 4096), + _ => { + return ServerAuthResponse::Error(ServerAuthError::UnsupportedAuthType); + } + }; + let tx = ServerTransaction::default(); + self.state = ServerAuthState::SASL(tx, scram); + ServerAuthResponse::Initial(AuthType::ScramSha256, Vec::new()) + } + } + } +} diff --git a/rust/pgrust/src/handshake/server_state_machine.rs b/rust/pgrust/src/handshake/server_state_machine.rs index bf3b728e65e..129286bb877 100644 --- a/rust/pgrust/src/handshake/server_state_machine.rs +++ b/rust/pgrust/src/handshake/server_state_machine.rs @@ -1,18 +1,21 @@ -use super::ConnectionSslRequirement; +use super::{ + server_auth::{ServerAuth, ServerAuthError}, + ConnectionSslRequirement, +}; use crate::{ - auth::{ServerTransaction, StoredHash, StoredKey}, + auth::{AuthType, CredentialData}, connection::ConnectionError, errors::{ PgError, PgErrorConnectionException, PgErrorFeatureNotSupported, PgErrorInvalidAuthorizationSpecification, PgServerError, PgServerErrorField, }, - handshake::AuthType, + handshake::server_auth::{ServerAuthDrive, ServerAuthResponse}, protocol::{ - builder, definition::BackendBuilder, match_message, InitialMessage, Message, ParseError, - PasswordMessage, SASLInitialResponse, SASLResponse, SSLRequest, StartupMessage, + match_message, + postgres::{data::*, *}, + ParseError, StructBuffer, }, }; -use rand::Rng; use std::str::Utf8Error; use tracing::{error, trace, warn}; @@ -25,62 +28,10 @@ pub enum ConnectionStateType { Ready, } -#[derive(Debug, Clone)] -pub struct ServerCredentials { - pub auth_type: AuthType, - pub credential_data: CredentialData, -} - -#[derive(Debug, Clone)] -pub enum CredentialData { - /// A credential that always succeeds, regardless of input password. Due to - /// the design of SCRAM-SHA-256, this cannot be used with that auth type. - Trust, - /// A credential that always fails, regardless of the input password. - Deny, - /// A plain-text password. - Plain(String), - /// A stored MD5 hash + salt. - Md5(StoredHash), - /// A stored SCRAM-SHA-256 key. - Scram(StoredKey), -} - -impl CredentialData { - pub fn new(ty: AuthType, username: String, password: String) -> Self { - match ty { - AuthType::Deny => Self::Deny, - AuthType::Trust => Self::Trust, - AuthType::Plain => Self::Plain(password), - AuthType::Md5 => Self::Md5(StoredHash::generate(password.as_bytes(), &username)), - AuthType::ScramSha256 => { - let salt: [u8; 32] = rand::thread_rng().gen(); - Self::Scram(StoredKey::generate(password.as_bytes(), &salt, 4096)) - } - } - } -} - -/// Internal flag used to store a predetermined result: ie, a connection that -/// must succeed for fail regardless of the correctness of the credential. -/// -/// Used for testing purposes, and to disguise timing in cases where a user may -/// not exist. -#[derive(Debug, Clone, Eq, PartialEq)] -enum PredeterminedResult { - Trust, - Deny, -} - -#[derive(Debug)] -struct ServerEnvironmentImpl { - ssl_requirement: ConnectionSslRequirement, - pid: i32, - key: i32, -} - #[derive(Debug)] pub enum ConnectionDrive<'a> { + /// Raw bytes from a client. + RawMessage(&'a [u8]), /// Initial message from client. Initial(Result, ParseError>), /// Non-initial message from the client. @@ -94,11 +45,11 @@ pub enum ConnectionDrive<'a> { /// Additionally, the environment can provide a "Trust" credential for automatic /// success or a "Deny" credential for automatic failure. The server will simulate /// a login process before unconditionally succeeding or failing in these cases. - AuthInfo(String, AuthType, CredentialData), + AuthInfo(AuthType, CredentialData), /// Once authorized, the server may sync any number of parameters until ready. Parameter(String, String), /// Ready, handshake complete. - Ready, + Ready(i32, i32), /// Fail the connection with a Postgres error code and message. Fail(PgError, &'a str), } @@ -111,7 +62,7 @@ pub trait ConnectionStateSend { /// Perform the SSL upgrade. fn upgrade(&mut self) -> Result<(), std::io::Error>; /// Notify the environment that a user and database were selected. - fn auth(&mut self, user: String, data: String) -> Result<(), std::io::Error>; + fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error>; /// Notify the environment that parameters are requested. fn params(&mut self) -> Result<(), std::io::Error>; } @@ -124,20 +75,70 @@ pub trait ConnectionStateUpdate: ConnectionStateSend { fn server_error(&mut self, error: &PgServerError) {} } +#[derive(Debug)] +pub enum ConnectionEvent<'a> { + SendSSL(builder::SSLResponse<'a>), + Send(BackendBuilder<'a>), + Upgrade, + Auth(String, String), + Params, + Parameter(&'a str, &'a str), + StateChanged(ConnectionStateType), + ServerError(&'a PgServerError), +} + +impl ConnectionStateSend for F +where + F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>, +{ + fn send_ssl(&mut self, message: builder::SSLResponse) -> Result<(), std::io::Error> { + self(ConnectionEvent::SendSSL(message)) + } + + fn send(&mut self, message: BackendBuilder) -> Result<(), std::io::Error> { + self(ConnectionEvent::Send(message)) + } + + fn upgrade(&mut self) -> Result<(), std::io::Error> { + self(ConnectionEvent::Upgrade) + } + + fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error> { + self(ConnectionEvent::Auth(user, database)) + } + + fn params(&mut self) -> Result<(), std::io::Error> { + self(ConnectionEvent::Params) + } +} + +impl ConnectionStateUpdate for F +where + F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>, +{ + fn parameter(&mut self, name: &str, value: &str) { + let _ = self(ConnectionEvent::Parameter(name, value)); + } + + fn state_changed(&mut self, state: ConnectionStateType) { + let _ = self(ConnectionEvent::StateChanged(state)); + } + + fn server_error(&mut self, error: &PgServerError) { + let _ = self(ConnectionEvent::ServerError(error)); + } +} + #[derive(Debug)] enum ServerStateImpl { - /// Initial state, boolean indicates whether SSL is required - Initial(bool), + /// Initial state, enum indicates whether SSL is required (or None if enabled) + Initial(Option), /// SSL connection is being established SslConnecting, + /// Waiting for AuthInfo + AuthInfo(String), /// Authentication process has begun - Authenticating, - /// Password-based authentication in progress - AuthenticatingPassword(String, CredentialData), - /// MD5 authentication in progress - AuthenticatingMD5(Option, [u8; 4], StoredHash), - /// SASL authentication in progress - AuthenticatingSASL(ServerTransaction, Option, StoredKey), + Authenticating(ServerAuth), /// Synchronizing connection parameters Synchronizing, /// Connection is ready for queries @@ -146,9 +147,13 @@ enum ServerStateImpl { Error, } +#[derive(derive_more::Debug)] pub struct ServerState { state: ServerStateImpl, - environment: ServerEnvironmentImpl, + #[debug(skip)] + initial_buffer: StructBuffer, + #[debug(skip)] + buffer: StructBuffer, } fn send_error( @@ -187,29 +192,48 @@ enum ServerError { Utf8Error(#[from] Utf8Error), } +impl From for ServerError { + fn from(value: ServerAuthError) -> Self { + match value { + ServerAuthError::InvalidAuthorizationSpecification => { + ServerError::Protocol(PgError::InvalidAuthorizationSpecification( + PgErrorInvalidAuthorizationSpecification::InvalidAuthorizationSpecification, + )) + } + ServerAuthError::InvalidPassword => { + ServerError::Protocol(PgError::InvalidAuthorizationSpecification( + PgErrorInvalidAuthorizationSpecification::InvalidPassword, + )) + } + ServerAuthError::InvalidSaslMessage(_) => ServerError::Protocol( + PgError::ConnectionException(PgErrorConnectionException::ProtocolViolation), + ), + ServerAuthError::UnsupportedAuthType => ServerError::Protocol( + PgError::FeatureNotSupported(PgErrorFeatureNotSupported::FeatureNotSupported), + ), + ServerAuthError::InvalidMessageType => ServerError::Protocol( + PgError::ConnectionException(PgErrorConnectionException::ProtocolViolation), + ), + } + } +} + const PROTOCOL_ERROR: ServerError = ServerError::Protocol(PgError::ConnectionException( PgErrorConnectionException::ProtocolViolation, )); const AUTH_ERROR: ServerError = ServerError::Protocol(PgError::InvalidAuthorizationSpecification( PgErrorInvalidAuthorizationSpecification::InvalidAuthorizationSpecification, )); -const PASSWORD_ERROR: ServerError = - ServerError::Protocol(PgError::InvalidAuthorizationSpecification( - PgErrorInvalidAuthorizationSpecification::InvalidPassword, - )); const PROTOCOL_VERSION_ERROR: ServerError = ServerError::Protocol(PgError::FeatureNotSupported( PgErrorFeatureNotSupported::FeatureNotSupported, )); impl ServerState { - pub fn new(ssl_requirement: ConnectionSslRequirement, pid: i32, key: i32) -> Self { + pub fn new(ssl_requirement: ConnectionSslRequirement) -> Self { Self { - state: ServerStateImpl::Initial(false), - environment: ServerEnvironmentImpl { - ssl_requirement, - pid, - key, - }, + state: ServerStateImpl::Initial(Some(ssl_requirement)), + initial_buffer: Default::default(), + buffer: Default::default(), } } @@ -230,7 +254,26 @@ impl ServerState { drive: ConnectionDrive, update: &mut impl ConnectionStateUpdate, ) -> Result<(), ConnectionError> { - match self.drive_inner(drive, update) { + trace!("SERVER DRIVE: {:?} {:?}", self.state, drive); + let res = match drive { + ConnectionDrive::RawMessage(raw) => match self.state { + ServerStateImpl::Initial(..) => self.initial_buffer.push_fallible(raw, |message| { + self.state + .drive_inner(ConnectionDrive::Initial(message), update) + }), + ServerStateImpl::Authenticating(..) => self.buffer.push_fallible(raw, |message| { + self.state + .drive_inner(ConnectionDrive::Message(message), update) + }), + _ => { + error!("Unexpected drive in state {:?}", self.state); + Err(PROTOCOL_ERROR) + } + }, + drive => self.state.drive_inner(drive, update), + }; + + match res { Ok(_) => Ok(()), Err(ServerError::IO(e)) => Err(e.into()), Err(ServerError::Utf8Error(e)) => Err(e.into()), @@ -241,7 +284,9 @@ impl ServerState { } } } +} +impl ServerStateImpl { fn drive_inner( &mut self, drive: ConnectionDrive, @@ -249,8 +294,8 @@ impl ServerState { ) -> Result<(), ServerError> { use ServerStateImpl::*; - match (&mut self.state, drive) { - (Initial(ssl_active), ConnectionDrive::Initial(initial_message)) => { + match (&mut *self, drive) { + (Initial(ssl), ConnectionDrive::Initial(initial_message)) => { match_message!(initial_message, InitialMessage { (StartupMessage as startup) => { let mut user = String::new(); @@ -265,26 +310,24 @@ impl ServerState { update.parameter(param.name().to_str()?, param.value().to_str()?); } if user.is_empty() { - // Postgres returns invalid_authorization_specification if no user is specified return Err(AUTH_ERROR); } if database.is_empty() { - // Postgres uses the username as the database if not specified database = user.clone(); } - self.state = Authenticating; + *self = AuthInfo(user.clone()); update.auth(user, database)?; }, (SSLRequest) => { - if *ssl_active { + let Some(ssl) = *ssl else { return Err(PROTOCOL_ERROR); - } - if self.environment.ssl_requirement == ConnectionSslRequirement::Disable { + }; + if ssl == ConnectionSslRequirement::Disable { update.send_ssl(builder::SSLResponse { code: b'N' })?; update.upgrade()?; } else { update.send_ssl(builder::SSLResponse { code: b'S' })?; - self.state = SslConnecting; + *self = SslConnecting; } }, unknown => { @@ -293,224 +336,137 @@ impl ServerState { }); } (SslConnecting, ConnectionDrive::SslReady) => { - self.state = Initial(true); + *self = Initial(None); } (SslConnecting, _) => { return Err(PROTOCOL_ERROR); } - (Authenticating, ConnectionDrive::AuthInfo(username, auth_type, credential_data)) => { - match auth_type { - AuthType::Deny => { - return Err(AUTH_ERROR); - } - AuthType::Trust => { - update.send(BackendBuilder::AuthenticationOk(Default::default()))?; - self.state = Synchronizing; - update.params()?; - } - AuthType::Plain => { + (AuthInfo(username), ConnectionDrive::AuthInfo(auth_type, credential_data)) => { + let mut auth = ServerAuth::new(username.clone(), auth_type, credential_data); + match auth.drive(ServerAuthDrive::Initial) { + ServerAuthResponse::Initial(AuthType::Plain, _) => { update.send(BackendBuilder::AuthenticationCleartextPassword( Default::default(), ))?; - self.state = AuthenticatingPassword(username, credential_data); } - AuthType::Md5 => { - let salt = rand::random(); - let (result, hash) = match credential_data { - CredentialData::Deny => { - let md5 = StoredHash::generate(b"", &username); - (Some(PredeterminedResult::Deny), md5) - } - CredentialData::Trust => { - let md5 = StoredHash::generate(b"", &username); - (Some(PredeterminedResult::Trust), md5) - } - CredentialData::Md5(md5) => (None, md5), - CredentialData::Plain(password) => { - let md5 = StoredHash::generate(password.as_bytes(), &username); - (None, md5) - } - CredentialData::Scram(..) => { - return Err(AUTH_ERROR); - } - }; - self.state = AuthenticatingMD5(result, salt, hash); + ServerAuthResponse::Initial(AuthType::Md5, salt) => { update.send(BackendBuilder::AuthenticationMD5Password( - builder::AuthenticationMD5Password { salt }, + builder::AuthenticationMD5Password { + salt: salt.try_into().unwrap(), + }, ))?; } - AuthType::ScramSha256 => { - let salt: [u8; 32] = rand::random(); - match credential_data { - CredentialData::Trust | CredentialData::Md5(..) => { - return Err(AUTH_ERROR); - } - CredentialData::Deny => { - // Create fake scram data - let scram = StoredKey::generate("".as_bytes(), &salt, 4096); - self.state = AuthenticatingSASL( - ServerTransaction::default(), - Some(PredeterminedResult::Deny), - scram, - ); - } - CredentialData::Plain(password) => { - // Upgrade password to SCRAM - let scram = StoredKey::generate(password.as_bytes(), &salt, 4096); - self.state = - AuthenticatingSASL(ServerTransaction::default(), None, scram); - } - CredentialData::Scram(scram) => { - self.state = - AuthenticatingSASL(ServerTransaction::default(), None, scram); - } - } + ServerAuthResponse::Initial(AuthType::ScramSha256, _) => { update.send(BackendBuilder::AuthenticationSASL( builder::AuthenticationSASL { mechanisms: &["SCRAM-SHA-256"], }, ))?; } + ServerAuthResponse::Complete(..) => { + update.send(BackendBuilder::AuthenticationOk(Default::default()))?; + *self = Synchronizing; + update.params()?; + return Ok(()); + } + ServerAuthResponse::Error(e) => { + error!("Authentication error in initial state: {e:?}"); + return Err(e.into()); + } + response => { + error!("Unexpected response: {response:?}"); + return Err(PROTOCOL_ERROR); + } } + *self = Authenticating(auth); } - (AuthenticatingPassword(username, data), ConnectionDrive::Message(message)) => { + (Authenticating(auth), ConnectionDrive::Message(message)) => { + trace!("auth = {auth:?}, initial = {}", auth.is_initial_message()); match_message!(message, Message { - (PasswordMessage as password) => { - let client_password = password.password(); - let success = match data { - CredentialData::Deny => { - false - }, - CredentialData::Trust => { - true - }, - CredentialData::Plain(password) => { - let md5_1 = StoredHash::generate(password.as_bytes(), username); - let md5_2 = StoredHash::generate(client_password.to_bytes(), username); - md5_1 == md5_2 + (PasswordMessage as password) if matches!(auth.auth_type(), AuthType::Plain | AuthType::Md5) => { + match auth.drive(ServerAuthDrive::Message(auth.auth_type(), password.password().to_bytes())) { + ServerAuthResponse::Complete(..) => { + update.send(BackendBuilder::AuthenticationOk(Default::default()))?; + *self = Synchronizing; + update.params()?; } - CredentialData::Md5(md5) => { - let md5_1 = StoredHash::generate(client_password.to_bytes(), username); - md5_1 == *md5 + ServerAuthResponse::Error(e) => { + error!("Authentication error for password message: {e:?}"); + return Err(e.into()) }, - CredentialData::Scram(scram) => { - // We can test a password by hashing it with the same salt and iteration count - let key = StoredKey::generate(client_password.to_bytes(), &scram.salt, scram.iterations); - key.stored_key == scram.stored_key + response => { + error!("Unexpected response for password message: {response:?}"); + return Err(PROTOCOL_ERROR); } - }; - if success { - update.send(BackendBuilder::AuthenticationOk(Default::default()))?; - self.state = Synchronizing; - update.params()?; - } else { - return Err(PASSWORD_ERROR); } }, - unknown => { - log_unknown_message(unknown, "Password")?; - } - }); - } - (AuthenticatingMD5(results, salt, md5), ConnectionDrive::Message(message)) => { - match_message!(message, Message { - (PasswordMessage as password) => { - let password = password.password(); - let success = match (results, md5) { - (Some(PredeterminedResult::Deny), _) => { - false - }, - (Some(PredeterminedResult::Trust), _) => { - true + (SASLInitialResponse as sasl) if auth.is_initial_message() => { + if sasl.mechanism() != "SCRAM-SHA-256" { + error!("Unexpected mechanism: {:?}", sasl.mechanism()); + return Err(PROTOCOL_ERROR); + } + match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.response().as_ref())) { + ServerAuthResponse::Continue(final_message) => { + update.send(BackendBuilder::AuthenticationSASLContinue(builder::AuthenticationSASLContinue { + data: &final_message, + }))?; + } + ServerAuthResponse::Error(e) => { + error!("Authentication error for SASL initial response: {e:?}"); + return Err(e.into()) }, - (None, md5) => { - md5.matches(password.to_bytes(), *salt) + response => { + error!("Unexpected response for SASL initial response: {response:?}"); + return Err(PROTOCOL_ERROR); + } + } + }, + (SASLResponse as sasl) if !auth.is_initial_message() => { + match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.response().as_ref())) { + ServerAuthResponse::Complete(data) => { + update.send(BackendBuilder::AuthenticationSASLFinal(builder::AuthenticationSASLFinal { + data: &data, + }))?; + update.send(BackendBuilder::AuthenticationOk(Default::default()))?; + *self = Synchronizing; + update.params()?; + } + ServerAuthResponse::Error(e) => { + error!("Authentication error for SASL response: {e:?}"); + return Err(e.into()) }, - }; - if success { - update.send(BackendBuilder::AuthenticationOk(Default::default()))?; - self.state = Synchronizing; - update.params()?; - } else { - return Err(PASSWORD_ERROR); + response => { + error!("Unexpected response for SASL response: {response:?}"); + return Err(PROTOCOL_ERROR); + } } }, unknown => { - log_unknown_message(unknown, "MD5")?; + log_unknown_message(unknown, "Authenticating")?; } }); } - (AuthenticatingSASL(tx, result, data), ConnectionDrive::Message(message)) => { - if tx.initial() { - match_message!(message, Message { - (SASLInitialResponse as sasl) => { - match tx.process_message(sasl.response().as_ref(), data) { - Ok(Some(final_message)) => { - update.send(BackendBuilder::AuthenticationSASLContinue(builder::AuthenticationSASLContinue { - data: &final_message, - }))?; - }, - Ok(None) => return Err(PASSWORD_ERROR), - Err(e) => { - error!("SCRAM auth failed: {e:?}"); - return Err(PASSWORD_ERROR); - } - } - }, - unknown => { - warn!("Protocol error: unknown or malformed message: {unknown:?}"); - return Err(PROTOCOL_ERROR); - } - }); - } else { - match_message!(message, Message { - (SASLResponse as sasl) => { - match tx.process_message(sasl.response().as_ref(), data) { - Ok(Some(final_message)) => { - if *result == Some(PredeterminedResult::Deny) { - return Err(PASSWORD_ERROR) - } - self.state = Synchronizing; - update.send(BackendBuilder::AuthenticationSASLFinal(builder::AuthenticationSASLFinal { - data: &final_message, - }))?; - update.send(BackendBuilder::AuthenticationOk(Default::default()))?; - update.params()?; - }, - Ok(None) => return Err(PASSWORD_ERROR), - Err(e) => { - error!("SCRAM auth failed: {e:?}"); - return Err(PASSWORD_ERROR); - } - } - }, - unknown => { - log_unknown_message(unknown, "SASL")?; - } - }); - }; - } (Synchronizing, ConnectionDrive::Parameter(name, value)) => { update.send(BackendBuilder::ParameterStatus(builder::ParameterStatus { name: &name, value: &value, }))?; } - (Synchronizing, ConnectionDrive::Ready) => { + (Synchronizing, ConnectionDrive::Ready(pid, key)) => { update.send(BackendBuilder::BackendKeyData(builder::BackendKeyData { - key: self.environment.key, - pid: self.environment.pid, + pid, + key, }))?; update.send(BackendBuilder::ReadyForQuery(builder::ReadyForQuery { status: b'I', }))?; - self.state = Ready; + *self = Ready; } (_, ConnectionDrive::Fail(error, _)) => { return Err(ServerError::Protocol(error)); } _ => { - error!("Unexpected drive in state {:?}", self.state); + error!("Unexpected drive in state {:?}", self); return Err(PROTOCOL_ERROR); } } diff --git a/rust/pgrust/src/protocol/buffer.rs b/rust/pgrust/src/protocol/buffer.rs index dcbb10999c4..6837407b41a 100644 --- a/rust/pgrust/src/protocol/buffer.rs +++ b/rust/pgrust/src/protocol/buffer.rs @@ -142,12 +142,21 @@ impl StructBuffer { pub fn into_inner(self) -> VecDeque { self.accum } + + pub fn is_empty(&self) -> bool { + self.accum.is_empty() + } + + pub fn len(&self) -> usize { + self.accum.len() + } } #[cfg(test)] mod tests { use super::StructBuffer; - use crate::protocol::{builder, meta, Encoded, Message, ParseError}; + use crate::protocol::postgres::{builder, data::*, meta}; + use crate::protocol::*; /// Create a test data buffer containing three messages fn test_data() -> (Vec, Vec) { diff --git a/rust/pgrust/src/protocol/datatypes.rs b/rust/pgrust/src/protocol/datatypes.rs index 3a2eb9de787..adb6cae2799 100644 --- a/rust/pgrust/src/protocol/datatypes.rs +++ b/rust/pgrust/src/protocol/datatypes.rs @@ -1,5 +1,7 @@ use std::{marker::PhantomData, str::Utf8Error}; +use uuid::Uuid; + use super::{ arrays::{array_access, Array, ArrayMeta}, field_access, @@ -9,8 +11,10 @@ use super::{ pub mod meta { pub use super::EncodedMeta as Encoded; + pub use super::LStringMeta as LString; pub use super::LengthMeta as Length; pub use super::RestMeta as Rest; + pub use super::UuidMeta as Uuid; pub use super::ZTStringMeta as ZTString; } @@ -201,6 +205,178 @@ impl FieldAccess { } } +/// A length-prefixed string. +#[allow(unused)] +pub struct LString<'a> { + buf: &'a [u8], +} + +field_access!(LStringMeta); +array_access!(LStringMeta); + +pub struct LStringMeta {} +impl Meta for LStringMeta { + fn name(&self) -> &'static str { + "LString" + } +} + +impl Enliven for LStringMeta { + type WithLifetime<'a> = LString<'a>; + type ForMeasure<'a> = &'a str; + type ForBuilder<'a> = &'a str; +} + +impl std::fmt::Debug for LString<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + String::from_utf8_lossy(self.buf).fmt(f) + } +} + +impl<'a> LString<'a> { + pub fn to_owned(&self) -> Result { + std::str::from_utf8(self.buf).map(|s| s.to_owned()) + } + + pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> { + std::str::from_utf8(self.buf) + } + + pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + String::from_utf8_lossy(self.buf) + } + + pub fn to_bytes(&self) -> &[u8] { + self.buf + } +} + +impl PartialEq for LString<'_> { + fn eq(&self, other: &Self) -> bool { + self.buf == other.buf + } +} +impl Eq for LString<'_> {} + +impl PartialEq for LString<'_> { + fn eq(&self, other: &str) -> bool { + self.buf == other.as_bytes() + } +} + +impl PartialEq<&str> for LString<'_> { + fn eq(&self, other: &&str) -> bool { + self.buf == other.as_bytes() + } +} + +impl<'a> TryInto<&'a str> for LString<'a> { + type Error = Utf8Error; + fn try_into(self) -> Result<&'a str, Self::Error> { + std::str::from_utf8(self.buf) + } +} + +impl FieldAccess { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + &LStringMeta {} + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + if buf.len() < 4 { + return Err(ParseError::TooShort); + } + let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + Ok(4 + len) + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> Result, ParseError> { + if buf.len() < 4 { + return Err(ParseError::TooShort); + } + let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + if buf.len() < 4 + len { + return Err(ParseError::TooShort); + } + Ok(LString { + buf: buf.split_at(4).1, + }) + } + #[inline(always)] + pub const fn measure(buf: &str) -> usize { + 4 + buf.len() + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut BufWriter, value: &str) { + let len = value.len() as u32; + buf.write(&len.to_be_bytes()); + buf.write(value.as_bytes()); + } + #[inline(always)] + pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &str) { + let len = value.len() as u32; + buf.write(&len.to_be_bytes()); + buf.write(value.as_bytes()); + } +} + +field_access!(UuidMeta); +array_access!(UuidMeta); + +pub struct UuidMeta {} +impl Meta for UuidMeta { + fn name(&self) -> &'static str { + "Uuid" + } +} + +impl Enliven for UuidMeta { + type WithLifetime<'a> = Uuid; + type ForMeasure<'a> = Uuid; + type ForBuilder<'a> = Uuid; +} + +impl FieldAccess { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + &UuidMeta {} + } + + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + if buf.len() < 16 { + Err(ParseError::TooShort) + } else { + Ok(16) + } + } + + #[inline(always)] + pub const fn extract(buf: &[u8]) -> Result { + if let Some(bytes) = buf.first_chunk() { + Ok(Uuid::from_u128(::from_be_bytes(*bytes))) + } else { + Err(ParseError::TooShort) + } + } + + #[inline(always)] + pub const fn measure(_value: &Uuid) -> usize { + 16 + } + + #[inline(always)] + pub fn copy_to_buf(buf: &mut BufWriter, value: Uuid) { + buf.write(value.as_bytes().as_slice()); + } + + #[inline(always)] + pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &Uuid) { + buf.write(&value.as_bytes().as_slice()); + } +} + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] /// An encoded row value. pub enum Encoded<'a> { @@ -514,7 +690,7 @@ macro_rules! basic_types { } } - basic_types!(: array<$ty> u8 i16 i32); + basic_types!(: array<$ty> u8 i16 i32 u32 u64); )* }; @@ -600,4 +776,4 @@ macro_rules! basic_types { )* } } -basic_types!(u8 i16 i32); +basic_types!(u8 i16 i32 u32 u64); diff --git a/rust/pgrust/src/protocol/edgedb.rs b/rust/pgrust/src/protocol/edgedb.rs new file mode 100644 index 00000000000..98a25c2155b --- /dev/null +++ b/rust/pgrust/src/protocol/edgedb.rs @@ -0,0 +1,503 @@ +use super::gen::protocol; +use crate::protocol::message_group::message_group; +message_group!( + EdgeDBBackend: Message = [ + AuthenticationOk, + AuthenticationRequiredSASLMessage, + AuthenticationSASLContinue, + AuthenticationSASLFinal, + ServerKeyData, + ParameterStatus, + ServerHandshake, + ReadyForCommand, + RestoreReady, + CommandComplete, + CommandDataDescription, + StateDataDescription, + Data, + DumpHeader, + DumpBlock, + ErrorResponse, + LogMessage + ] +); + +message_group!( + EdgeDBFrontend: Message = [ + ClientHandshake, + AuthenticationSASLInitialResponse, + AuthenticationSASLResponse, + Parse, + Execute, + Sync, + Flush, + Terminate, + Dump, + Restore, + RestoreBlock, + RestoreEof + ] +); + +protocol!( + +/// A generic base for all EdgeDB mtype/mlen-style messages. +struct Message { + /// Identifies the message. + mtype: u8, + /// Length of message contents in bytes, including self. + mlen: len, + /// Message contents. + data: Rest, +} + +/// The `ErrorResponse` struct represents an error message sent from the server. +struct ErrorResponse: Message { + /// Identifies the message as an error response. + mtype: u8 = 'E', + /// Length of message contents in bytes, including self. + mlen: len, + /// Message severity. + severity: u8, + /// Message code. + error_code: i32, + /// Error message. + message: LString, + /// Error attributes. + attributes: Array, +} + +/// The `LogMessage` struct represents a log message sent from the server. +struct LogMessage: Message { + /// Identifies the message as a log message. + mtype: u8 = 'L', + /// Length of message contents in bytes, including self. + mlen: len, + /// Message severity. + severity: u8, + /// Message code. + code: i32, + /// Message text. + text: LString, + /// Message annotations. + annotations: Array, +} + +/// The `ReadyForCommand` struct represents a message indicating the server is ready for a new command. +struct ReadyForCommand: Message { + /// Identifies the message as ready for command. + mtype: u8 = 'Z', + /// Length of message contents in bytes, including self. + mlen: len, + /// Message annotations. + annotations: Array, + /// Transaction state. + transaction_state: u8, +} + +/// The `RestoreReady` struct represents a message indicating the server is ready for restore. +struct RestoreReady: Message { + /// Identifies the message as restore ready. + mtype: u8 = '+', + /// Length of message contents in bytes, including self. + mlen: len, + /// Message annotations. + annotations: Array, + /// Number of parallel jobs for restore. + jobs: i16, +} + +/// The `CommandComplete` struct represents a message indicating a command has completed. +struct CommandComplete: Message { + /// Identifies the message as command complete. + mtype: u8 = 'C', + /// Length of message contents in bytes, including self. + mlen: len, + /// Message annotations. + annotations: Array, + /// A bit mask of allowed capabilities. + capabilities: u64, + /// Command status. + status: LString, + /// State data descriptor ID. + state_typedesc_id: Uuid, + /// Encoded state data. + state_data: Array, +} + +/// The `CommandDataDescription` struct represents a description of command data. +struct CommandDataDescription: Message { + /// Identifies the message as command data description. + mtype: u8 = 'T', + /// Length of message contents in bytes, including self. + mlen: len, + /// Message annotations. + annotations: Array, + /// A bit mask of allowed capabilities. + capabilities: u64, + /// Actual result cardinality. + result_cardinality: u8, + /// Argument data descriptor ID. + input_typedesc_id: Uuid, + /// Argument data descriptor. + input_typedesc: Array, + /// Output data descriptor ID. + output_typedesc_id: Uuid, + /// Output data descriptor. + output_typedesc: Array, +} + +/// The `StateDataDescription` struct represents a description of state data. +struct StateDataDescription: Message { + /// Identifies the message as state data description. + mtype: u8 = 's', + /// Length of message contents in bytes, including self. + mlen: len, + /// Updated state data descriptor ID. + typedesc_id: Uuid, + /// State data descriptor. + typedesc: Array, +} + +/// The `Data` struct represents a data message. +struct Data: Message { + /// Identifies the message as data. + mtype: u8 = 'D', + /// Length of message contents in bytes, including self. + mlen: len, + /// Encoded output data array. + data: Array, +} + +/// The `DumpHeader` struct represents a dump header message. +struct DumpHeader: Message { + /// Identifies the message as dump header. + mtype: u8 = '@', + /// Length of message contents in bytes, including self. + mlen: len, + /// Dump attributes. + attributes: Array, + /// Major version of EdgeDB. + major_ver: i16, + /// Minor version of EdgeDB. + minor_ver: i16, + /// Schema. + schema_ddl: LString, + /// Type identifiers. + types: Array, + /// Object descriptors. + descriptors: Array, +} + +/// The `DumpBlock` struct represents a dump block message. +struct DumpBlock: Message { + /// Identifies the message as dump block. + mtype: u8 = '=', + /// Length of message contents in bytes, including self. + mlen: len, + /// Dump attributes. + attributes: Array, +} + +/// The `ServerKeyData` struct represents server key data. +struct ServerKeyData: Message { + /// Identifies the message as server key data. + mtype: u8 = 'K', + /// Length of message contents in bytes, including self. + mlen: len, + /// Key data. + data: [u8; 32], +} + +/// The `ParameterStatus` struct represents a parameter status message. +struct ParameterStatus: Message { + /// Identifies the message as parameter status. + mtype: u8 = 'S', + /// Length of message contents in bytes, including self. + mlen: len, + /// Parameter name. + name: Array, + /// Parameter value. + value: Array, +} + +/// The `ServerHandshake` struct represents a server handshake message. +struct ServerHandshake: Message { + /// Identifies the message as server handshake. + mtype: u8 = 'v', + /// Length of message contents in bytes, including self. + mlen: len, + /// Maximum supported or client-requested protocol major version. + major_ver: i16, + /// Maximum supported or client-requested protocol minor version. + minor_ver: i16, + /// Supported protocol extensions. + extensions: Array, +} + +/// The `AuthenticationOk` struct represents a successful authentication message. +struct AuthenticationOk: Message { + /// Identifies the message as authentication OK. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that this message contains a successful authentication indicator. + auth_status: i32 = 0x0, +} + +/// The `AuthenticationRequiredSASLMessage` struct represents a SASL authentication request. +struct AuthenticationRequiredSASLMessage: Message { + /// Identifies the message as authentication required SASL. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that this message contains a SASL authentication request. + auth_status: i32 = 0x0A, + /// A list of supported SASL authentication methods. + methods: Array, +} + +/// The `AuthenticationSASLContinue` struct represents a SASL challenge. +struct AuthenticationSASLContinue: Message { + /// Identifies the message as authentication SASL continue. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that this message contains a SASL challenge. + auth_status: i32 = 0x0B, + /// Mechanism-specific SASL data. + sasl_data: Array, +} + +/// The `AuthenticationSASLFinal` struct represents the completion of SASL authentication. +struct AuthenticationSASLFinal: Message { + /// Identifies the message as authentication SASL final. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that SASL authentication has completed. + auth_status: i32 = 0x0C, + /// SASL data. + sasl_data: Array, +} + +/// The `Dump` struct represents a dump message from the client. +struct Dump: Message { + /// Identifies the message as dump. + mtype: u8 = '>', + /// Length of message contents in bytes, including self. + mlen: len, + /// Message annotations. + annotations: Array, +} + +/// The `Sync` struct represents a synchronization message from the client. +struct Sync: Message { + /// Identifies the message as sync. + mtype: u8 = 'S', + /// Length of message contents in bytes, including self. + mlen: len, +} + +/// The `Flush` struct represents a flush message from the client. +struct Flush: Message { + /// Identifies the message as flush. + mtype: u8 = 'H', + /// Length of message contents in bytes, including self. + mlen: len, +} + +/// The `Restore` struct represents a restore message from the client. +struct Restore: Message { + /// Identifies the message as restore. + mtype: u8 = '<', + /// Length of message contents in bytes, including self. + mlen: len, + /// Restore attributes. + attributes: Array, + /// Number of parallel jobs for restore. + jobs: i16, + /// Original DumpHeader packet data excluding mtype and message_length. + header_data: Array, +} + +/// The `RestoreBlock` struct represents a restore block message from the client. +struct RestoreBlock: Message { + /// Identifies the message as restore block. + mtype: u8 = '=', + /// Length of message contents in bytes, including self. + mlen: len, + /// Original DumpBlock packet data excluding mtype and message_length. + block_data: Array, +} + +/// The `RestoreEof` struct represents the end of restore message from the client. +struct RestoreEof: Message { + /// Identifies the message as restore EOF. + mtype: u8 = '.', + /// Length of message contents in bytes, including self. + mlen: len, +} + +/// The `Parse` struct represents a parse message from the client. +struct Parse: Message { + /// Identifies the message as parse. + mtype: u8 = 'P', + /// Length of message contents in bytes, including self. + mlen: len, + /// Message annotations. + annotations: Array, + /// A bit mask of allowed capabilities. + allowed_capabilities: u64, + /// A bit mask of query options. + compilation_flags: u64, + /// Implicit LIMIT clause on returned sets. + implicit_limit: u64, + /// Data output format. + output_format: u8, + /// Expected result cardinality. + expected_cardinality: u8, + /// Command text. + command_text: LString, + /// State data descriptor ID. + state_typedesc_id: Uuid, + /// Encoded state data. + state_data: Array, +} + +/// The `Execute` struct represents an execute message from the client. +struct Execute: Message { + /// Identifies the message as execute. + mtype: u8 = 'O', + /// Length of message contents in bytes, including self. + mlen: len, + /// Message annotations. + annotations: Array, + /// A bit mask of allowed capabilities. + allowed_capabilities: u64, + /// A bit mask of query options. + compilation_flags: u64, + /// Implicit LIMIT clause on returned sets. + implicit_limit: u64, + /// Data output format. + output_format: u8, + /// Expected result cardinality. + expected_cardinality: u8, + /// Command text. + command_text: LString, + /// State data descriptor ID. + state_typedesc_id: Uuid, + /// Encoded state data. + state_data: Array, + /// Argument data descriptor ID. + input_typedesc_id: Uuid, + /// Output data descriptor ID. + output_typedesc_id: Uuid, + /// Encoded argument data. + arguments: Array, +} + +/// The `ClientHandshake` struct represents a client handshake message. +struct ClientHandshake: Message { + /// Identifies the message as client handshake. + mtype: u8 = 'V', + /// Length of message contents in bytes, including self. + mlen: len, + /// Requested protocol major version. + major_ver: i16, + /// Requested protocol minor version. + minor_ver: i16, + /// Connection parameters. + params: Array, + /// Requested protocol extensions. + extensions: Array, +} + +/// The `Terminate` struct represents a termination message from the client. +struct Terminate: Message { + /// Identifies the message as terminate. + mtype: u8 = 'X', + /// Length of message contents in bytes, including self. + mlen: len, +} + +/// The `AuthenticationSASLInitialResponse` struct represents the initial SASL response from the client. +struct AuthenticationSASLInitialResponse: Message { + /// Identifies the message as authentication SASL initial response. + mtype: u8 = 'p', + /// Length of message contents in bytes, including self. + mlen: len, + /// Name of the SASL authentication mechanism that the client selected. + method: LString, + /// Mechanism-specific "Initial Response" data. + sasl_data: Array, +} + +/// The `AuthenticationSASLResponse` struct represents a SASL response from the client. +struct AuthenticationSASLResponse: Message { + /// Identifies the message as authentication SASL response. + mtype: u8 = 'r', + /// Length of message contents in bytes, including self. + mlen: len, + /// Mechanism-specific response data. + sasl_data: Array, +} + +/// The `KeyValue` struct represents a key-value pair. +struct KeyValue { + /// Key code (specific to the type of the Message). + code: i16, + /// Value data. + value: Array, +} + +/// The `Annotation` struct represents an annotation. +struct Annotation { + /// Name of the annotation. + name: LString, + /// Value of the annotation (in JSON format). + value: LString, +} + +/// The `DataElement` struct represents a data element. +struct DataElement { + /// Encoded output data. + data: Array, +} + +/// The `DumpTypeInfo` struct represents type information in a dump. +struct DumpTypeInfo { + /// Type name. + type_name: LString, + /// Type class. + type_class: LString, + /// Type ID. + type_id: Uuid, +} + +/// The `DumpObjectDesc` struct represents an object descriptor in a dump. +struct DumpObjectDesc { + /// Object ID. + object_id: Uuid, + /// Description. + description: Array, + /// Dependencies. + dependencies: Array, +} + +/// The `ProtocolExtension` struct represents a protocol extension. +struct ProtocolExtension { + /// Extension name. + name: LString, + /// A set of extension annotations. + annotations: Array, +} + +/// The `ConnectionParam` struct represents a connection parameter. +struct ConnectionParam { + /// Parameter name. + name: LString, + /// Parameter value. + value: LString, +} +); diff --git a/rust/pgrust/src/protocol/gen.rs b/rust/pgrust/src/protocol/gen.rs index 85b88b78226..c44f87956c8 100644 --- a/rust/pgrust/src/protocol/gen.rs +++ b/rust/pgrust/src/protocol/gen.rs @@ -85,17 +85,26 @@ macro_rules! struct_elaborate { struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::protocol::meta::Length), size(fixed=fixed), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); }; // Pattern match on known fixed-sized types and mark them as `size(fixed=fixed)` - (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type([u8; 4])($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type([u8; $len:literal])($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(u8)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+1)) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr)fields([type(i16)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+2)) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(i32)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + }; + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(u32)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + }; + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(u64)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + }; + (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(Uuid)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; // Fallback for other types - variable sized @@ -164,9 +173,10 @@ macro_rules! protocol { ($( $( #[ $sdoc:meta ] )* struct $name:ident $(: $super:ident)? { $($struct:tt)+ } )+) => { $( paste::paste!( - pub(crate) mod [<$name:lower>] { - #[allow(unused_imports)] - use super::*; + #[allow(unused_imports)] + pub(crate) mod [<__ $name:lower>] { + use super::meta::*; + use $crate::protocol::meta::*; use $crate::protocol::gen::*; struct_elaborate!(protocol_builder(__struct__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); struct_elaborate!(protocol_builder(__meta__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); @@ -180,7 +190,7 @@ macro_rules! protocol { #![allow(unused_imports)] $( paste::paste!( - pub use super::[<$name:lower>]::$name; + pub use super::[<__ $name:lower>]::$name; ); )+ } @@ -188,7 +198,7 @@ macro_rules! protocol { #![allow(unused_imports)] $( paste::paste!( - pub use super::[<$name:lower>]::[<$name Meta>] as $name; + pub use super::[<__ $name:lower>]::[<$name Meta>] as $name; ); )+ @@ -205,7 +215,7 @@ macro_rules! protocol { #![allow(unused_imports)] $( paste::paste!( - pub use super::[<$name:lower>]::[<$name Builder>] as $name; + pub use super::[<__ $name:lower>]::[<$name Builder>] as $name; ); )+ } @@ -213,7 +223,7 @@ macro_rules! protocol { #![allow(unused_imports)] $( paste::paste!( - pub use super::[<$name:lower>]::[<$name Measure>] as $name; + pub use super::[<__ $name:lower>]::[<$name Measure>] as $name; ); )+ } @@ -718,12 +728,24 @@ mod tests { ); } + mod string { + use crate::protocol::meta::LString; + protocol!( + struct HasLString { + s: LString, + } + ); + } + macro_rules! assert_stringify { (($($struct:tt)*), ($($expected:tt)*)) => { struct_elaborate!(assert_stringify(__internal__ ($($expected)*)) => $($struct)*); }; (__internal__ ($($expected:tt)*), $($struct:tt)*) => { - assert_eq!(stringify!($($struct)*), stringify!($($expected)*)); + // We don't want whitespace to impact this comparison + if stringify!($($struct)*).replace(char::is_whitespace, "") != stringify!($($expected)*).replace(char::is_whitespace, "") { + assert_eq!(stringify!($($struct)*), stringify!($($expected)*)); + } }; } @@ -749,7 +771,7 @@ mod tests { { name(b), type (u8), size(fixed = fixed), value(no_value = no_value), docs(concat!("`", stringify! (b), "` field.")), - fixed(fixed_offset = fixed_offset, ((0) + 1)), + fixed(fixed_offset = fixed_offset, ((0) + std::mem::size_of::())), },), })); } @@ -775,13 +797,13 @@ mod tests { { name(l), type (crate::protocol::meta::Length), size(fixed = fixed), value(auto = auto), docs(concat!("`", stringify! (l), "` field.")), - fixed(fixed_offset = fixed_offset, ((0) + 1)), + fixed(fixed_offset = fixed_offset, ((0) + std::mem::size_of::())), }, { name(s), type (ZTString), size(variable = variable), value(no_value = no_value), docs(concat!("`", stringify! (s), "` field.")), - fixed(fixed_offset = fixed_offset, (((0) + 1) + 4)), + fixed(fixed_offset = fixed_offset, (((0) + std::mem::size_of::()) + 4)), }, { name(c), type (i16), size(fixed = fixed), value(no_value = no_value), @@ -792,13 +814,14 @@ mod tests { name(d), type ([u8; 4]), size(fixed = fixed), value(no_value = no_value), docs(concat!("`", stringify! (d), "` field.")), - fixed(no_fixed_offset = no_fixed_offset, ((0) + 2)), + fixed(no_fixed_offset = no_fixed_offset, ((0) + std::mem::size_of::())), }, { name(e), type (ZTArray), size(variable = variable), value(no_value = no_value), docs(concat!("`", stringify! (e), "` field.")), - fixed(no_fixed_offset = no_fixed_offset, (((0) + 2) + 4)), + fixed(no_fixed_offset = no_fixed_offset, + (((0) + std::mem::size_of::()) + std::mem::size_of::<[u8; 4]>())), }, ), })); diff --git a/rust/pgrust/src/protocol/message_group.rs b/rust/pgrust/src/protocol/message_group.rs index 04058a22dde..0f5f0720857 100644 --- a/rust/pgrust/src/protocol/message_group.rs +++ b/rust/pgrust/src/protocol/message_group.rs @@ -1,5 +1,5 @@ macro_rules! message_group { - ($(#[$doc:meta])* $group:ident : $super:ident = [$($message:ty),*]) => { + ($(#[$doc:meta])* $group:ident : $super:ident = [$($message:ident),*]) => { paste::paste!( $(#[$doc])* #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -28,6 +28,14 @@ macro_rules! message_group { )* } } + + pub fn copy_to_buf(&self, writer: &mut $crate::protocol::writer::BufWriter) { + match self { + $( + Self::$message(message) => message.copy_to_buf(writer), + )* + } + } } $( @@ -57,7 +65,7 @@ macro_rules! message_group { impl $group { pub fn identify(buf: &[u8]) -> Option { $( - if <$message as $crate::protocol::Enliven>::WithLifetime::is_buffer(buf) { + if ::WithLifetime::is_buffer(buf) { return Some(Self::$message); } )* @@ -74,7 +82,7 @@ pub(crate) use message_group; /// /// ```rust /// use pgrust::protocol::*; -/// use pgrust::protocol::messages::*; +/// use pgrust::protocol::postgres::data::*; /// /// let buf = [b'?', 0, 0, 0, 4]; /// match_message!(Message::new(&buf), Backend { @@ -90,7 +98,7 @@ pub(crate) use message_group; #[macro_export] macro_rules! __match_message { ($buf:expr, $messages:ty { - $(( $i1:path $(as $i2:ident )?) => $impl:block,)* + $(( $i1:path $(as $i2:ident )?) $(if $cond:expr)? => $impl:block,)* $unknown:ident => $unknown_impl:block $(,)? }) => { 'block: { @@ -98,7 +106,7 @@ macro_rules! __match_message { let res = match __message { Ok(__message) => { $( - if <$i1>::is_buffer(&__message.as_ref()) { + if $($cond &&)? <$i1>::is_buffer(&__message.as_ref()) { match(<$i1>::new(&__message.as_ref())) { Ok(__tmp) => { $(let $i2 = __tmp;)? @@ -130,7 +138,10 @@ pub use __match_message as match_message; #[cfg(test)] mod tests { use super::*; - use crate::protocol::{builder, Message, PasswordMessage}; + use crate::protocol::postgres::{ + builder, + data::{Message, PasswordMessage}, + }; #[test] fn test_match() { diff --git a/rust/pgrust/src/protocol/mod.rs b/rust/pgrust/src/protocol/mod.rs index 57eb9564b4d..d83b8354454 100644 --- a/rust/pgrust/src/protocol/mod.rs +++ b/rust/pgrust/src/protocol/mod.rs @@ -1,41 +1,25 @@ mod arrays; mod buffer; mod datatypes; -pub(crate) mod definition; +pub mod edgedb; mod gen; mod message_group; +pub mod postgres; mod writer; /// Metatypes for the protocol and related arrays/strings. pub mod meta { pub use super::arrays::meta::*; pub use super::datatypes::meta::*; - pub use super::definition::meta::*; -} - -/// Measurement structs. -pub mod measure { - pub use super::definition::measure::*; -} - -/// Builder structs. -pub mod builder { - pub use super::definition::builder::*; -} - -/// Message types collections. -pub mod messages { - pub use super::definition::{Backend, Frontend, Initial}; } #[allow(unused)] pub use arrays::{Array, ArrayIter, ZTArray, ZTArrayIter}; pub use buffer::StructBuffer; #[allow(unused)] -pub use datatypes::{Encoded, Rest, ZTString}; -#[allow(unused)] -pub use definition::data::*; +pub use datatypes::{Encoded, LString, Rest, ZTString}; pub use message_group::match_message; +pub use writer::BufWriter; #[derive(thiserror::Error, Debug, Clone, Copy, PartialEq, Eq)] pub enum ParseError { @@ -179,7 +163,7 @@ pub(crate) use field_access; mod tests { use super::*; use buffer::StructBuffer; - use definition::builder; + use postgres::{builder, data::*, measure, meta}; use rand::Rng; /// We want to ensure that no malformed messages will cause unexpected /// panics, so we try all sorts of combinations of message mutation to @@ -625,4 +609,15 @@ mod tests { fuzz_test::(message); } + + #[test] + fn test_edgedb_sasl() { + use crate::protocol::edgedb::*; + + assert_eq!(builder::AuthenticationRequiredSASLMessage { + methods: &["SCRAM-SHA-256"] + }.to_vec(), vec![82, 0, 0, 0, 29, 0, 0, 0, 10, 0, 0, 0, 1, 0, 0, 0, 13, 83, 67, 82, 65, 77, 45, 83, 72, 65, 45, 50, 53, 54]); + + + } } diff --git a/rust/pgrust/src/protocol/postgres.rs b/rust/pgrust/src/protocol/postgres.rs new file mode 100644 index 00000000000..04bbdc106f5 --- /dev/null +++ b/rust/pgrust/src/protocol/postgres.rs @@ -0,0 +1,739 @@ +use super::gen::protocol; +use super::message_group::message_group; + +message_group!( + /// The `Backend` message group contains messages sent from the backend to the frontend. + Backend: Message = [ + AuthenticationOk, + AuthenticationKerberosV5, + AuthenticationCleartextPassword, + AuthenticationMD5Password, + AuthenticationGSS, + AuthenticationGSSContinue, + AuthenticationSSPI, + AuthenticationSASL, + AuthenticationSASLContinue, + AuthenticationSASLFinal, + BackendKeyData, + BindComplete, + CloseComplete, + CommandComplete, + CopyData, + CopyDone, + CopyInResponse, + CopyOutResponse, + CopyBothResponse, + DataRow, + EmptyQueryResponse, + ErrorResponse, + FunctionCallResponse, + NegotiateProtocolVersion, + NoData, + NoticeResponse, + NotificationResponse, + ParameterDescription, + ParameterStatus, + ParseComplete, + PortalSuspended, + ReadyForQuery, + RowDescription + ] +); + +message_group!( + /// The `Frontend` message group contains messages sent from the frontend to the backend. + Frontend: Message = [ + Bind, + Close, + CopyData, + CopyDone, + CopyFail, + Describe, + Execute, + Flush, + FunctionCall, + GSSResponse, + Parse, + PasswordMessage, + Query, + SASLInitialResponse, + SASLResponse, + Sync, + Terminate + ] +); + +message_group!( + /// The `Initial` message group contains messages that are sent before the + /// normal message flow. + Initial: InitialMessage = [ + CancelRequest, + GSSENCRequest, + SSLRequest, + StartupMessage + ] +); + +protocol!( + +/// A generic base for all Postgres mtype/mlen-style messages. +struct Message { + /// Identifies the message. + mtype: u8, + /// Length of message contents in bytes, including self. + mlen: len, + /// Message contents. + data: Rest, +} + +/// A generic base for all initial Postgres messages. +struct InitialMessage { + /// Length of message contents in bytes, including self. + mlen: len, + /// The identifier for this initial message. + protocol_version: i32, + /// Message contents. + data: Rest +} + +/// The `AuthenticationMessage` struct is a base for all Postgres authentication messages. +struct AuthenticationMessage: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that the authentication was successful. + status: i32, +} + +/// The `AuthenticationOk` struct represents a message indicating successful authentication. +struct AuthenticationOk: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// Specifies that the authentication was successful. + status: i32 = 0, +} + +/// The `AuthenticationKerberosV5` struct represents a message indicating that Kerberos V5 authentication is required. +struct AuthenticationKerberosV5: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// Specifies that Kerberos V5 authentication is required. + status: i32 = 2, +} + +/// The `AuthenticationCleartextPassword` struct represents a message indicating that a cleartext password is required for authentication. +struct AuthenticationCleartextPassword: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// Specifies that a clear-text password is required. + status: i32 = 3, +} + +/// The `AuthenticationMD5Password` struct represents a message indicating that an MD5-encrypted password is required for authentication. +struct AuthenticationMD5Password: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 12, + /// Specifies that an MD5-encrypted password is required. + status: i32 = 5, + /// The salt to use when encrypting the password. + salt: [u8; 4], +} + +/// The `AuthenticationSCMCredential` struct represents a message indicating that an SCM credential is required for authentication. +struct AuthenticationSCMCredential: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 6, + /// Any data byte, which is ignored. + byte: u8 = 0, +} + +/// The `AuthenticationGSS` struct represents a message indicating that GSSAPI authentication is required. +struct AuthenticationGSS: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// Specifies that GSSAPI authentication is required. + status: i32 = 7, +} + +/// The `AuthenticationGSSContinue` struct represents a message indicating the continuation of GSSAPI authentication. +struct AuthenticationGSSContinue: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that this message contains GSSAPI or SSPI data. + status: i32 = 8, + /// GSSAPI or SSPI authentication data. + data: Rest, +} + +/// The `AuthenticationSSPI` struct represents a message indicating that SSPI authentication is required. +struct AuthenticationSSPI: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// Specifies that SSPI authentication is required. + status: i32 = 9, +} + +/// The `AuthenticationSASL` struct represents a message indicating that SASL authentication is required. +struct AuthenticationSASL: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that SASL authentication is required. + status: i32 = 10, + /// List of SASL authentication mechanisms, terminated by a zero byte. + mechanisms: ZTArray, +} + +/// The `AuthenticationSASLContinue` struct represents a message containing a SASL challenge during the authentication process. +struct AuthenticationSASLContinue: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that this message contains a SASL challenge. + status: i32 = 11, + /// SASL data, specific to the SASL mechanism being used. + data: Rest, +} + +/// The `AuthenticationSASLFinal` struct represents a message indicating the completion of SASL authentication. +struct AuthenticationSASLFinal: Message { + /// Identifies the message as an authentication request. + mtype: u8 = 'R', + /// Length of message contents in bytes, including self. + mlen: len, + /// Specifies that SASL authentication has completed. + status: i32 = 12, + /// SASL outcome "additional data", specific to the SASL mechanism being used. + data: Rest, +} + +/// The `BackendKeyData` struct represents a message containing the process ID and secret key for this backend. +struct BackendKeyData: Message { + /// Identifies the message as cancellation key data. + mtype: u8 = 'K', + /// Length of message contents in bytes, including self. + mlen: len = 12, + /// The process ID of this backend. + pid: i32, + /// The secret key of this backend. + key: i32, +} + +/// The `Bind` struct represents a message to bind a named portal to a prepared statement. +struct Bind: Message { + /// Identifies the message as a Bind command. + mtype: u8 = 'B', + /// Length of message contents in bytes, including self. + mlen: len, + /// The name of the destination portal. + portal: ZTString, + /// The name of the source prepared statement. + statement: ZTString, + /// The parameter format codes. + format_codes: Array, + /// Array of parameter values and their lengths. + values: Array, + /// The result-column format codes. + result_format_codes: Array, +} + +/// The `BindComplete` struct represents a message indicating that a Bind operation was successful. +struct BindComplete: Message { + /// Identifies the message as a Bind-complete indicator. + mtype: u8 = '2', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `CancelRequest` struct represents a message to request the cancellation of a query. +struct CancelRequest: InitialMessage { + /// Length of message contents in bytes, including self. + mlen: len = 16, + /// The cancel request code. + code: i32 = 80877102, + /// The process ID of the target backend. + pid: i32, + /// The secret key for the target backend. + key: i32, +} + +/// The `Close` struct represents a message to close a prepared statement or portal. +struct Close: Message { + /// Identifies the message as a Close command. + mtype: u8 = 'C', + /// Length of message contents in bytes, including self. + mlen: len, + /// 'xS' to close a prepared statement; 'P' to close a portal. + ctype: u8, + /// The name of the prepared statement or portal to close. + name: ZTString, +} + +/// The `CloseComplete` struct represents a message indicating that a Close operation was successful. +struct CloseComplete: Message { + /// Identifies the message as a Close-complete indicator. + mtype: u8 = '3', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `CommandComplete` struct represents a message indicating the successful completion of a command. +struct CommandComplete: Message { + /// Identifies the message as a command-completed response. + mtype: u8 = 'C', + /// Length of message contents in bytes, including self. + mlen: len, + /// The command tag. + tag: ZTString, +} + +/// The `CopyData` struct represents a message containing data for a copy operation. +struct CopyData: Message { + /// Identifies the message as COPY data. + mtype: u8 = 'd', + /// Length of message contents in bytes, including self. + mlen: len, + /// Data that forms part of a COPY data stream. + data: Rest, +} + +/// The `CopyDone` struct represents a message indicating that a copy operation is complete. +struct CopyDone: Message { + /// Identifies the message as a COPY-complete indicator. + mtype: u8 = 'c', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `CopyFail` struct represents a message indicating that a copy operation has failed. +struct CopyFail: Message { + /// Identifies the message as a COPY-failure indicator. + mtype: u8 = 'f', + /// Length of message contents in bytes, including self. + mlen: len, + /// An error message to report as the cause of failure. + error_msg: ZTString, +} + +/// The `CopyInResponse` struct represents a message indicating that the server is ready to receive data for a copy-in operation. +struct CopyInResponse: Message { + /// Identifies the message as a Start Copy In response. + mtype: u8 = 'G', + /// Length of message contents in bytes, including self. + mlen: len, + /// 0 for textual, 1 for binary. + format: u8, + /// The format codes for each column. + format_codes: Array, +} + +/// The `CopyOutResponse` struct represents a message indicating that the server is ready to send data for a copy-out operation. +struct CopyOutResponse: Message { + /// Identifies the message as a Start Copy Out response. + mtype: u8 = 'H', + /// Length of message contents in bytes, including self. + mlen: len, + /// 0 for textual, 1 for binary. + format: u8, + /// The format codes for each column. + format_codes: Array, +} + +/// The `CopyBothResponse` is used only for Streaming Replication. +struct CopyBothResponse: Message { + /// Identifies the message as a Start Copy Both response. + mtype: u8 = 'W', + /// Length of message contents in bytes, including self. + mlen: len, + /// 0 for textual, 1 for binary. + format: u8, + /// The format codes for each column. + format_codes: Array, +} + +/// The `DataRow` struct represents a message containing a row of data. +struct DataRow: Message { + /// Identifies the message as a data row. + mtype: u8 = 'D', + /// Length of message contents in bytes, including self. + mlen: len, + /// Array of column values and their lengths. + values: Array, +} + +/// The `Describe` struct represents a message to describe a prepared statement or portal. +struct Describe: Message { + /// Identifies the message as a Describe command. + mtype: u8 = 'D', + /// Length of message contents in bytes, including self. + mlen: len, + /// 'S' to describe a prepared statement; 'P' to describe a portal. + dtype: u8, + /// The name of the prepared statement or portal. + name: ZTString, +} + +/// The `EmptyQueryResponse` struct represents a message indicating that an empty query string was recognized. +struct EmptyQueryResponse: Message { + /// Identifies the message as a response to an empty query String. + mtype: u8 = 'I', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `ErrorResponse` struct represents a message indicating that an error has occurred. +struct ErrorResponse: Message { + /// Identifies the message as an error. + mtype: u8 = 'E', + /// Length of message contents in bytes, including self. + mlen: len, + /// Array of error fields and their values. + fields: ZTArray, +} + +/// The `ErrorField` struct represents a single error message within an `ErrorResponse`. +struct ErrorField { + /// A code identifying the field type. + etype: u8, + /// The field value. + value: ZTString, +} + +/// The `Execute` struct represents a message to execute a prepared statement or portal. +struct Execute: Message { + /// Identifies the message as an Execute command. + mtype: u8 = 'E', + /// Length of message contents in bytes, including self. + mlen: len, + /// The name of the portal to execute. + portal: ZTString, + /// Maximum number of rows to return. + max_rows: i32, +} + +/// The `Flush` struct represents a message to flush the backend's output buffer. +struct Flush: Message { + /// Identifies the message as a Flush command. + mtype: u8 = 'H', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `FunctionCall` struct represents a message to call a function. +struct FunctionCall: Message { + /// Identifies the message as a function call. + mtype: u8 = 'F', + /// Length of message contents in bytes, including self. + mlen: len, + /// OID of the function to execute. + function_id: i32, + /// The parameter format codes. + format_codes: Array, + /// Array of args and their lengths. + args: Array, + /// The format code for the result. + result_format_code: i16, +} + +/// The `FunctionCallResponse` struct represents a message containing the result of a function call. +struct FunctionCallResponse: Message { + /// Identifies the message as a function-call response. + mtype: u8 = 'V', + /// Length of message contents in bytes, including self. + mlen: len, + /// The function result value. + result: Encoded, +} + +/// The `GSSENCRequest` struct represents a message requesting GSSAPI encryption. +struct GSSENCRequest: InitialMessage { + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// The GSSAPI Encryption request code. + gssenc_request_code: i32 = 80877104, +} + +/// The `GSSResponse` struct represents a message containing a GSSAPI or SSPI response. +struct GSSResponse: Message { + /// Identifies the message as a GSSAPI or SSPI response. + mtype: u8 = 'p', + /// Length of message contents in bytes, including self. + mlen: len, + /// GSSAPI or SSPI authentication data. + data: Rest, +} + +/// The `NegotiateProtocolVersion` struct represents a message requesting protocol version negotiation. +struct NegotiateProtocolVersion: Message { + /// Identifies the message as a protocol version negotiation request. + mtype: u8 = 'v', + /// Length of message contents in bytes, including self. + mlen: len, + /// Newest minor protocol version supported by the server. + minor_version: i32, + /// List of protocol options not recognized. + options: Array, +} + +/// The `NoData` struct represents a message indicating that there is no data to return. +struct NoData: Message { + /// Identifies the message as a No Data indicator. + mtype: u8 = 'n', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `NoticeResponse` struct represents a message containing a notice. +struct NoticeResponse: Message { + /// Identifies the message as a notice. + mtype: u8 = 'N', + /// Length of message contents in bytes, including self. + mlen: len, + /// Array of notice fields and their values. + fields: ZTArray, +} + +/// The `NoticeField` struct represents a single error message within an `NoticeResponse`. +struct NoticeField: Message { + /// A code identifying the field type. + ntype: u8, + /// The field value. + value: ZTString, +} + +/// The `NotificationResponse` struct represents a message containing a notification from the backend. +struct NotificationResponse: Message { + /// Identifies the message as a notification. + mtype: u8 = 'A', + /// Length of message contents in bytes, including self. + mlen: len, + /// The process ID of the notifying backend. + pid: i32, + /// The name of the notification channel. + channel: ZTString, + /// The notification payload. + payload: ZTString, +} + +/// The `ParameterDescription` struct represents a message describing the parameters needed by a prepared statement. +struct ParameterDescription: Message { + /// Identifies the message as a parameter description. + mtype: u8 = 't', + /// Length of message contents in bytes, including self. + mlen: len, + /// OIDs of the parameter data types. + param_types: Array, +} + +/// The `ParameterStatus` struct represents a message containing the current status of a parameter. +struct ParameterStatus: Message { + /// Identifies the message as a runtime parameter status report. + mtype: u8 = 'S', + /// Length of message contents in bytes, including self. + mlen: len, + /// The name of the parameter. + name: ZTString, + /// The current value of the parameter. + value: ZTString, +} + +/// The `Parse` struct represents a message to parse a query string. +struct Parse: Message { + /// Identifies the message as a Parse command. + mtype: u8 = 'P', + /// Length of message contents in bytes, including self. + mlen: len, + /// The name of the destination prepared statement. + statement: ZTString, + /// The query String to be parsed. + query: ZTString, + /// OIDs of the parameter data types. + param_types: Array, +} + +/// The `ParseComplete` struct represents a message indicating that a Parse operation was successful. +struct ParseComplete: Message { + /// Identifies the message as a Parse-complete indicator. + mtype: u8 = '1', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `PasswordMessage` struct represents a message containing a password. +struct PasswordMessage: Message { + /// Identifies the message as a password response. + mtype: u8 = 'p', + /// Length of message contents in bytes, including self. + mlen: len, + /// The password (encrypted or plaintext, depending on context). + password: ZTString, +} + +/// The `PortalSuspended` struct represents a message indicating that a portal has been suspended. +struct PortalSuspended: Message { + /// Identifies the message as a portal-suspended indicator. + mtype: u8 = 's', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `Query` struct represents a message to execute a simple query. +struct Query: Message { + /// Identifies the message as a simple query command. + mtype: u8 = 'Q', + /// Length of message contents in bytes, including self. + mlen: len, + /// The query String to be executed. + query: ZTString, +} + +/// The `ReadyForQuery` struct represents a message indicating that the backend is ready for a new query. +struct ReadyForQuery: Message { + /// Identifies the message as a ready-for-query indicator. + mtype: u8 = 'Z', + /// Length of message contents in bytes, including self. + mlen: len = 5, + /// Current transaction status indicator. + status: u8, +} + +/// The `RowDescription` struct represents a message describing the rows that will be returned by a query. +struct RowDescription: Message { + /// Identifies the message as a row description. + mtype: u8 = 'T', + /// Length of message contents in bytes, including self. + mlen: len, + /// Array of field descriptions. + fields: Array, +} + +/// The `RowField` struct represents a row within the `RowDescription` message. +struct RowField { + /// The field name + name: ZTString, + /// The table ID (OID) of the table the column is from, or 0 if not a column reference + table_oid: i32, + /// The attribute number of the column, or 0 if not a column reference + column_attr_number: i16, + /// The object ID of the field's data type + data_type_oid: i32, + /// The data type size (negative if variable size) + data_type_size: i16, + /// The type modifier + type_modifier: i32, + /// The format code being used for the field (0 for text, 1 for binary) + format_code: i16, +} + +/// The `SASLInitialResponse` struct represents a message containing a SASL initial response. +struct SASLInitialResponse: Message { + /// Identifies the message as a SASL initial response. + mtype: u8 = 'p', + /// Length of message contents in bytes, including self. + mlen: len, + /// Name of the SASL authentication mechanism. + mechanism: ZTString, + /// SASL initial response data. + response: Array, +} + +/// The `SASLResponse` struct represents a message containing a SASL response. +struct SASLResponse: Message { + /// Identifies the message as a SASL response. + mtype: u8 = 'p', + /// Length of message contents in bytes, including self. + mlen: len, + /// SASL response data. + response: Rest, +} + +/// The `SSLRequest` struct represents a message requesting SSL encryption. +struct SSLRequest: InitialMessage { + /// Length of message contents in bytes, including self. + mlen: len = 8, + /// The SSL request code. + code: i32 = 80877103, +} + +struct SSLResponse { + /// Specifies if SSL was accepted or rejected. + code: u8, +} + +/// The `StartupMessage` struct represents a message to initiate a connection. +struct StartupMessage: InitialMessage { + /// Length of message contents in bytes, including self. + mlen: len, + /// The protocol version number. + protocol: i32 = 196608, + /// List of parameter name-value pairs, terminated by a zero byte. + params: ZTArray, +} + +/// The `StartupMessage` struct represents a name/value pair within the `StartupMessage` message. +struct StartupNameValue { + /// The parameter name. + name: ZTString, + /// The parameter value. + value: ZTString, +} + +/// The `Sync` struct represents a message to synchronize the frontend and backend. +struct Sync: Message { + /// Identifies the message as a Sync command. + mtype: u8 = 'S', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} + +/// The `Terminate` struct represents a message to terminate a connection. +struct Terminate: Message { + /// Identifies the message as a Terminate command. + mtype: u8 = 'X', + /// Length of message contents in bytes, including self. + mlen: len = 4, +} +); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_all() { + let message = meta::Message::default(); + let initial_message = meta::InitialMessage::default(); + + for meta in meta::ALL { + eprintln!("{meta:#?}"); + if **meta != message && **meta != initial_message { + if meta.field("mtype").is_some() && meta.field("mlen").is_some() { + // If a message has mtype and mlen, it should subclass Message + assert_eq!(*meta.parent().unwrap(), message); + } else if meta.field("mlen").is_some() { + // If a message has mlen only, it should subclass InitialMessage + assert_eq!(*meta.parent().unwrap(), initial_message); + } + } + } + } +} diff --git a/rust/pgrust/src/python.rs b/rust/pgrust/src/python.rs index 749fe61fd04..889380dd6c9 100644 --- a/rust/pgrust/src/python.rs +++ b/rust/pgrust/src/python.rs @@ -11,7 +11,7 @@ use crate::{ }, ConnectionSslRequirement, }, - protocol::{meta, SSLResponse, StructBuffer}, + protocol::{postgres::{data::SSLResponse, meta, FrontendBuilder, InitialBuilder}, StructBuffer}, }; use pyo3::{ buffer::PyBuffer, @@ -90,7 +90,6 @@ impl PyConnectionParams { } #[getter] - #[allow(clippy::type_complexity)] pub fn host_candidates( &self, py: Python, @@ -355,7 +354,7 @@ impl PyConnectionState { let buffer = PyBuffer::::get(data)?; if self.inner.read_ssl_response() { // SSL responses are always one character - let response = [buffer.as_slice(py).unwrap().first().unwrap().get()]; + let response = [buffer.as_slice(py).unwrap().get(0).unwrap().get()]; let response = SSLResponse::new(&response)?; self.inner .drive(ConnectionDrive::SslResponse(response), &mut self.update)?; @@ -406,7 +405,7 @@ struct PyConnectionStateUpdate { impl ConnectionStateSend for PyConnectionStateUpdate { fn send_initial( &mut self, - message: crate::protocol::definition::InitialBuilder, + message: InitialBuilder, ) -> Result<(), std::io::Error> { Python::with_gil(|py| { let bytes = PyByteArray::new(py, &message.to_vec()); @@ -420,7 +419,7 @@ impl ConnectionStateSend for PyConnectionStateUpdate { fn send( &mut self, - message: crate::protocol::definition::FrontendBuilder, + message: FrontendBuilder, ) -> Result<(), std::io::Error> { Python::with_gil(|py| { let bytes = PyBytes::new(py, &message.to_vec()); @@ -477,7 +476,7 @@ impl ConnectionStateUpdate for PyConnectionStateUpdate { }); } - fn auth(&mut self, auth: crate::handshake::AuthType) { + fn auth(&mut self, auth: crate::auth::AuthType) { Python::with_gil(|py| { if let Err(e) = self.py_update.call_method1(py, "auth", (auth as u8,)) { eprintln!("Error in auth: {:?}", e); diff --git a/rust/pgrust/tests/real_postgres.rs b/rust/pgrust/tests/real_postgres.rs index b3dbb1fc810..a263ee3fc67 100644 --- a/rust/pgrust/tests/real_postgres.rs +++ b/rust/pgrust/tests/real_postgres.rs @@ -1,9 +1,10 @@ // Constants use openssl::ssl::{Ssl, SslContext, SslMethod}; +use pgrust::auth::AuthType; use pgrust::connection::dsn::{Host, HostType}; use pgrust::connection::{connect_raw_ssl, ConnectionError, Credentials, ResolvedTarget}; use pgrust::errors::PgServerError; -use pgrust::handshake::{AuthType, ConnectionSslRequirement}; +use pgrust::handshake::ConnectionSslRequirement; use rstest::rstest; use std::io::{BufRead, BufReader, Write}; use std::net::{Ipv4Addr, SocketAddr, TcpListener}; @@ -192,6 +193,7 @@ fn run_postgres( let content = std::fs::read_to_string(&pg_hba_path)?; let modified_content = content .lines() + .filter(|line| !line.starts_with("#") && !line.is_empty()) .map(|line| { if line.trim_start().starts_with("host") { line.replacen("host", "hostssl", 1) @@ -201,7 +203,7 @@ fn run_postgres( }) .collect::>() .join("\n"); - eprintln!("pg_hba.conf:\n{modified_content}"); + eprintln!("pg_hba.conf:\n==========\n{modified_content}\n=========="); std::fs::write(&pg_hba_path, modified_content)?; command.arg("-l"); From ea0c01ff16c1c8f81b6ed4e909eeb4302c519f4e Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Tue, 26 Nov 2024 11:26:04 -0700 Subject: [PATCH 2/6] Dep cleanup --- Cargo.lock | 30 ------------------------------ rust/pgrust/Cargo.toml | 14 ++++---------- rust/pgrust/src/protocol/mod.rs | 15 ++++++++++----- rust/pgrust/src/python.rs | 15 ++++++--------- 4 files changed, 20 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a04671812a8..57367622a18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -945,12 +945,6 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - [[package]] name = "hex-literal" version = "0.4.1" @@ -1552,14 +1546,11 @@ dependencies = [ "clap", "clap_derive", "constant_time_eq", - "consume_on_drop", "derive_more", "futures", - "hex", "hex-literal", "hexdump", "hmac", - "itertools 0.13.0", "libc", "lru", "md5", @@ -1573,14 +1564,10 @@ dependencies = [ "rstest", "scopeguard", "serde", - "serde-pickle", "serde_derive", "sha2", - "smart-default", "socket2", "statrs", - "stringprep", - "strum", "tempfile", "test-log", "thiserror", @@ -2331,17 +2318,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "stringprep" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" -dependencies = [ - "unicode-bidi", - "unicode-normalization", - "unicode-properties", -] - [[package]] name = "strsim" version = "0.11.1" @@ -2692,12 +2668,6 @@ dependencies = [ "tinyvec", ] -[[package]] -name = "unicode-properties" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ea75f83c0137a9b98608359a5f1af8144876eb67bcb1ce837368e906a9f524" - [[package]] name = "unicode-segmentation" version = "1.11.0" diff --git a/rust/pgrust/Cargo.toml b/rust/pgrust/Cargo.toml index a668bdda775..535341ac1f6 100644 --- a/rust/pgrust/Cargo.toml +++ b/rust/pgrust/Cargo.toml @@ -15,32 +15,23 @@ optimizer = [] [dependencies] pyo3.workspace = true tokio.workspace = true +tracing.workspace = true futures = "0" -scopeguard = "1" -itertools = "0" thiserror = "1" -tracing = "0" -tracing-subscriber = "0" -strum = { version = "0.26", features = ["derive"] } -consume_on_drop = "0" -smart-default = "0" openssl = { version = "0.10.66", features = ["v111"] } tokio-openssl = "0.6.4" paste = "1" unicode-normalization = "0.1.23" -stringprep = "0.1.5" hmac = "0.12" base64 = "0.22" sha2 = "0.10" -hex = "0.4.3" md5 = "0.7.0" rand = "0" hexdump = "0" url = "2" serde = "1" serde_derive = "1" -serde-pickle = "1" percent-encoding = "2" roaring = "0.10.6" constant_time_eq = "0.3" @@ -51,6 +42,9 @@ version = "1.0.0-beta.6" features = ["full"] [dev-dependencies] +tracing-subscriber.workspace = true +scopeguard = "1" + pretty_assertions = "1.2.0" test-log = { version = "0", features = ["trace"] } anyhow = "1" diff --git a/rust/pgrust/src/protocol/mod.rs b/rust/pgrust/src/protocol/mod.rs index d83b8354454..aead9728e4e 100644 --- a/rust/pgrust/src/protocol/mod.rs +++ b/rust/pgrust/src/protocol/mod.rs @@ -614,10 +614,15 @@ mod tests { fn test_edgedb_sasl() { use crate::protocol::edgedb::*; - assert_eq!(builder::AuthenticationRequiredSASLMessage { - methods: &["SCRAM-SHA-256"] - }.to_vec(), vec![82, 0, 0, 0, 29, 0, 0, 0, 10, 0, 0, 0, 1, 0, 0, 0, 13, 83, 67, 82, 65, 77, 45, 83, 72, 65, 45, 50, 53, 54]); - - + assert_eq!( + builder::AuthenticationRequiredSASLMessage { + methods: &["SCRAM-SHA-256"] + } + .to_vec(), + vec![ + 82, 0, 0, 0, 29, 0, 0, 0, 10, 0, 0, 0, 1, 0, 0, 0, 13, 83, 67, 82, 65, 77, 45, 83, + 72, 65, 45, 50, 53, 54 + ] + ); } } diff --git a/rust/pgrust/src/python.rs b/rust/pgrust/src/python.rs index 889380dd6c9..dd7d483bb20 100644 --- a/rust/pgrust/src/python.rs +++ b/rust/pgrust/src/python.rs @@ -11,7 +11,10 @@ use crate::{ }, ConnectionSslRequirement, }, - protocol::{postgres::{data::SSLResponse, meta, FrontendBuilder, InitialBuilder}, StructBuffer}, + protocol::{ + postgres::{data::SSLResponse, meta, FrontendBuilder, InitialBuilder}, + StructBuffer, + }, }; use pyo3::{ buffer::PyBuffer, @@ -403,10 +406,7 @@ struct PyConnectionStateUpdate { } impl ConnectionStateSend for PyConnectionStateUpdate { - fn send_initial( - &mut self, - message: InitialBuilder, - ) -> Result<(), std::io::Error> { + fn send_initial(&mut self, message: InitialBuilder) -> Result<(), std::io::Error> { Python::with_gil(|py| { let bytes = PyByteArray::new(py, &message.to_vec()); if let Err(e) = self.py_update.call_method1(py, "send", (bytes,)) { @@ -417,10 +417,7 @@ impl ConnectionStateSend for PyConnectionStateUpdate { Ok(()) } - fn send( - &mut self, - message: FrontendBuilder, - ) -> Result<(), std::io::Error> { + fn send(&mut self, message: FrontendBuilder) -> Result<(), std::io::Error> { Python::with_gil(|py| { let bytes = PyBytes::new(py, &message.to_vec()); if let Err(e) = self.py_update.call_method1(py, "send", (bytes,)) { From 5ae2a2dd278efeb8819e3d9853e5619d788f9ec7 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Tue, 26 Nov 2024 11:54:38 -0700 Subject: [PATCH 3/6] . --- rust/pgrust/src/handshake/edgedb_server.rs | 36 +++++++++---------- rust/pgrust/src/handshake/server_auth.rs | 8 ++--- .../src/handshake/server_state_machine.rs | 1 + rust/pgrust/src/protocol/datatypes.rs | 2 +- rust/pgrust/src/python.rs | 3 +- 5 files changed, 24 insertions(+), 26 deletions(-) diff --git a/rust/pgrust/src/handshake/edgedb_server.rs b/rust/pgrust/src/handshake/edgedb_server.rs index 2c325129aa9..d1c236c5d4a 100644 --- a/rust/pgrust/src/handshake/edgedb_server.rs +++ b/rust/pgrust/src/handshake/edgedb_server.rs @@ -32,10 +32,11 @@ pub enum ConnectionDrive<'a> { pub trait ConnectionStateSend { fn send(&mut self, message: EdgeDBBackendBuilder) -> Result<(), std::io::Error>; - fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error>; + fn auth(&mut self, user: String, database: String, branch: String) -> Result<(), std::io::Error>; fn params(&mut self) -> Result<(), std::io::Error>; } +#[allow(unused)] pub trait ConnectionStateUpdate: ConnectionStateSend { fn parameter(&mut self, name: &str, value: &str) {} fn state_changed(&mut self, state: ConnectionStateType) {} @@ -45,7 +46,7 @@ pub trait ConnectionStateUpdate: ConnectionStateSend { #[derive(Debug)] pub enum ConnectionEvent<'a> { Send(EdgeDBBackendBuilder<'a>), - Auth(String, String), + Auth(String, String, String), Params, Parameter(&'a str, &'a str), StateChanged(ConnectionStateType), @@ -60,8 +61,8 @@ where self(ConnectionEvent::Send(message)) } - fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error> { - self(ConnectionEvent::Auth(user, database)) + fn auth(&mut self, user: String, database: String, branch: String) -> Result<(), std::io::Error> { + self(ConnectionEvent::Auth(user, database, branch)) } fn params(&mut self) -> Result<(), std::io::Error> { @@ -118,8 +119,10 @@ const AUTH_ERROR: ServerError = ServerError::Protocol(EdbError::AuthenticationEr const PROTOCOL_VERSION_ERROR: ServerError = ServerError::Protocol(EdbError::UnsupportedProtocolVersionError); -#[derive(Debug)] +#[derive(Debug, Default)] +#[allow(clippy::large_enum_variant)] // Auth is much larger enum ServerStateImpl { + #[default] Initial, AuthInfo(String), Authenticating(ServerAuth), @@ -134,13 +137,6 @@ pub struct ServerState { } impl ServerState { - pub fn new() -> Self { - Self { - state: ServerStateImpl::Initial, - buffer: Default::default(), - } - } - pub fn is_ready(&self) -> bool { matches!(self.state, ServerStateImpl::Ready) } @@ -235,21 +231,21 @@ impl ServerStateImpl { update.parameter(param.name().to_str()?, param.value().to_str()?); } if user.is_empty() { - return Err(AUTH_ERROR.into()); + return Err(AUTH_ERROR); } if database.is_empty() { database = user.clone(); } *self = AuthInfo(user.clone()); - update.auth(user, database)?; + update.auth(user, database, branch)?; }, unknown => { log_unknown_message(unknown, "Initial")?; } }); } - (AuthInfo(_), ConnectionDrive::AuthInfo(auth_type, credential_data)) => { - let mut auth = ServerAuth::new(String::new(), auth_type, credential_data); + (AuthInfo(username), ConnectionDrive::AuthInfo(auth_type, credential_data)) => { + let mut auth = ServerAuth::new(username.clone(), auth_type, credential_data); match auth.drive(ServerAuthDrive::Initial) { ServerAuthResponse::Initial(AuthType::ScramSha256, _) => { update.send(EdgeDBBackendBuilder::AuthenticationRequiredSASLMessage( @@ -363,7 +359,7 @@ fn send_error( update.server_error(&code); update.send(EdgeDBBackendBuilder::ErrorResponse( builder::ErrorResponse { - severity: 0x78, + severity: ErrorSeverity::Error as _, error_code: code as i32, message, attributes: &[], @@ -372,7 +368,7 @@ fn send_error( } enum ErrorSeverity { - ERROR = 0x78, - FATAL = 0xc8, - PANIC = 0xff, + Error = 0x78, + Fatal = 0xc8, + Panic = 0xff, } diff --git a/rust/pgrust/src/handshake/server_auth.rs b/rust/pgrust/src/handshake/server_auth.rs index 2d3da6746da..7a6c0b9a01c 100644 --- a/rust/pgrust/src/handshake/server_auth.rs +++ b/rust/pgrust/src/handshake/server_auth.rs @@ -28,7 +28,7 @@ enum ServerAuthState { Initial, Password(CredentialData), MD5([u8; 4], CredentialData), - SASL(ServerTransaction, StoredKey), + Sasl(ServerTransaction, StoredKey), } #[derive(Debug)] @@ -58,7 +58,7 @@ impl ServerAuth { pub fn is_initial_message(&self) -> bool { match &self.state { ServerAuthState::Initial => false, - ServerAuthState::SASL(tx, _) => tx.initial(), + ServerAuthState::Sasl(tx, _) => tx.initial(), _ => true, } } @@ -114,7 +114,7 @@ impl ServerAuth { } } ( - ServerAuthState::SASL(tx, data), + ServerAuthState::Sasl(tx, data), ServerAuthDrive::Message(AuthType::ScramSha256, input), ) => { let initial = tx.initial(); @@ -171,7 +171,7 @@ impl ServerAuth { } }; let tx = ServerTransaction::default(); - self.state = ServerAuthState::SASL(tx, scram); + self.state = ServerAuthState::Sasl(tx, scram); ServerAuthResponse::Initial(AuthType::ScramSha256, Vec::new()) } } diff --git a/rust/pgrust/src/handshake/server_state_machine.rs b/rust/pgrust/src/handshake/server_state_machine.rs index 129286bb877..03549eb4e48 100644 --- a/rust/pgrust/src/handshake/server_state_machine.rs +++ b/rust/pgrust/src/handshake/server_state_machine.rs @@ -130,6 +130,7 @@ where } #[derive(Debug)] +#[allow(clippy::large_enum_variant)] enum ServerStateImpl { /// Initial state, enum indicates whether SSL is required (or None if enabled) Initial(Option), diff --git a/rust/pgrust/src/protocol/datatypes.rs b/rust/pgrust/src/protocol/datatypes.rs index adb6cae2799..99433de3ce4 100644 --- a/rust/pgrust/src/protocol/datatypes.rs +++ b/rust/pgrust/src/protocol/datatypes.rs @@ -373,7 +373,7 @@ impl FieldAccess { #[inline(always)] pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &Uuid) { - buf.write(&value.as_bytes().as_slice()); + buf.write(value.as_bytes().as_slice()); } } diff --git a/rust/pgrust/src/python.rs b/rust/pgrust/src/python.rs index dd7d483bb20..a9aef3509d5 100644 --- a/rust/pgrust/src/python.rs +++ b/rust/pgrust/src/python.rs @@ -93,6 +93,7 @@ impl PyConnectionParams { } #[getter] + #[allow(clippy::type_complexity)] pub fn host_candidates( &self, py: Python, @@ -357,7 +358,7 @@ impl PyConnectionState { let buffer = PyBuffer::::get(data)?; if self.inner.read_ssl_response() { // SSL responses are always one character - let response = [buffer.as_slice(py).unwrap().get(0).unwrap().get()]; + let response = [buffer.as_slice(py).unwrap().first().unwrap().get()]; let response = SSLResponse::new(&response)?; self.inner .drive(ConnectionDrive::SslResponse(response), &mut self.update)?; From 961bb32f3a4d58ebb626e98533042819da6c7fa1 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Tue, 26 Nov 2024 12:07:13 -0700 Subject: [PATCH 4/6] . --- rust/pgrust/src/handshake/server_state_machine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/pgrust/src/handshake/server_state_machine.rs b/rust/pgrust/src/handshake/server_state_machine.rs index 03549eb4e48..f8f3c2a506b 100644 --- a/rust/pgrust/src/handshake/server_state_machine.rs +++ b/rust/pgrust/src/handshake/server_state_machine.rs @@ -130,7 +130,7 @@ where } #[derive(Debug)] -#[allow(clippy::large_enum_variant)] +#[allow(clippy::large_enum_variant)] // Auth is much larger enum ServerStateImpl { /// Initial state, enum indicates whether SSL is required (or None if enabled) Initial(Option), From bebc98946586366942783a5431b592353432bd08 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Tue, 26 Nov 2024 12:36:27 -0700 Subject: [PATCH 5/6] . --- Cargo.lock | 75 +++++++++++++------ Cargo.toml | 3 + rust/auth/Cargo.toml | 30 ++++++++ rust/auth/README.md | 8 ++ rust/auth/src/handshake/mod.rs | 3 + .../src/handshake/server_auth.rs | 6 +- .../src/auth/mod.rs => auth/src/lib.rs} | 21 ++---- rust/{pgrust/src/auth => auth/src}/md5.rs | 3 +- rust/{pgrust/src/auth => auth/src}/scram.rs | 3 +- .../src/auth => auth/src}/stringprep.rs | 6 +- .../src/auth => auth/src}/stringprep_table.rs | 0 .../src}/stringprep_table_prep.py | 0 rust/http/Cargo.toml | 4 - rust/pgrust/Cargo.toml | 12 +-- rust/pgrust/src/connection/mod.rs | 3 +- rust/pgrust/src/connection/raw_conn.rs | 6 +- .../src/handshake/client_state_machine.rs | 15 ++-- rust/pgrust/src/handshake/edgedb_server.rs | 22 ++++-- rust/pgrust/src/handshake/mod.rs | 3 +- .../src/handshake/server_state_machine.rs | 11 ++- rust/pgrust/src/lib.rs | 1 - rust/pgrust/src/python.rs | 2 +- rust/pgrust/tests/real_postgres.rs | 2 +- 23 files changed, 152 insertions(+), 87 deletions(-) create mode 100644 rust/auth/Cargo.toml create mode 100644 rust/auth/README.md create mode 100644 rust/auth/src/handshake/mod.rs rename rust/{pgrust => auth}/src/handshake/server_auth.rs (98%) rename rust/{pgrust/src/auth/mod.rs => auth/src/lib.rs} (82%) rename rust/{pgrust/src/auth => auth/src}/md5.rs (97%) rename rust/{pgrust/src/auth => auth/src}/scram.rs (99%) rename rust/{pgrust/src/auth => auth/src}/stringprep.rs (98%) rename rust/{pgrust/src/auth => auth/src}/stringprep_table.rs (100%) rename rust/{pgrust/src/auth => auth/src}/stringprep_table_prep.py (100%) diff --git a/Cargo.lock b/Cargo.lock index 57367622a18..7a3d36cb912 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -367,7 +367,7 @@ dependencies = [ "statrs", "strum", "test-log", - "thiserror", + "thiserror 1.0.63", "tokio", "tracing", ] @@ -540,7 +540,7 @@ dependencies = [ "combine", "num-bigint 0.2.6", "num-traits", - "thiserror", + "thiserror 1.0.63", ] [[package]] @@ -584,7 +584,7 @@ dependencies = [ "serde_json", "sha2", "snafu", - "thiserror", + "thiserror 1.0.63", "unicode-width", ] @@ -840,6 +840,26 @@ dependencies = [ "slab", ] +[[package]] +name = "gel_auth" +version = "0.1.0" +dependencies = [ + "base64", + "constant_time_eq", + "derive_more", + "hex-literal", + "hmac", + "md5", + "pretty_assertions", + "rand", + "roaring", + "rstest", + "sha2", + "thiserror 2.0.3", + "tracing", + "unicode-normalization", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -901,7 +921,7 @@ dependencies = [ "num-traits", "pretty_assertions", "pyo3", - "thiserror", + "thiserror 1.0.63", ] [[package]] @@ -973,7 +993,6 @@ dependencies = [ name = "http" version = "0.1.0" dependencies = [ - "derive_more", "eventsource-stream", "futures", "pyo3", @@ -1540,42 +1559,32 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" name = "pgrust" version = "0.1.0" dependencies = [ - "anyhow", "base64", - "byteorder", "clap", "clap_derive", - "constant_time_eq", "derive_more", "futures", - "hex-literal", + "gel_auth", "hexdump", - "hmac", "libc", - "lru", - "md5", "openssl", "paste", "percent-encoding", "pretty_assertions", "pyo3", "rand", - "roaring", "rstest", "scopeguard", "serde", "serde_derive", - "sha2", "socket2", - "statrs", "tempfile", "test-log", - "thiserror", + "thiserror 1.0.63", "tokio", "tokio-openssl", "tracing", "tracing-subscriber", - "unicode-normalization", "url", "uuid", ] @@ -1676,9 +1685,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -1754,7 +1763,7 @@ dependencies = [ "nix", "pyo3", "scopeguard", - "thiserror", + "thiserror 1.0.63", "tokio", "tracing", "tracing-subscriber", @@ -2354,9 +2363,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.79" +version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ "proc-macro2", "quote", @@ -2440,7 +2449,16 @@ version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.63", +] + +[[package]] +name = "thiserror" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +dependencies = [ + "thiserror-impl 2.0.3", ] [[package]] @@ -2454,6 +2472,17 @@ dependencies = [ "syn", ] +[[package]] +name = "thiserror-impl" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.8" diff --git a/Cargo.toml b/Cargo.toml index d9750b71ba2..a7ec25a43fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "edb/edgeql-parser/edgeql-parser-python", "edb/graphql-rewrite", "edb/server/_rust_native", + "rust/auth", "rust/conn_pool", "rust/pgrust", "rust/http", @@ -17,6 +18,8 @@ pyo3 = { version = "0.23.1", features = ["extension-module", "serde", "macros"] tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros", "time", "sync", "net", "io-util"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter"] } + +gel_auth = { path = "rust/auth" } conn_pool = { path = "rust/conn_pool" } pgrust = { path = "rust/pgrust" } http = { path = "rust/http" } diff --git a/rust/auth/Cargo.toml b/rust/auth/Cargo.toml new file mode 100644 index 00000000000..f4ecf765b12 --- /dev/null +++ b/rust/auth/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "gel_auth" +version = "0.1.0" +license = "MIT/Apache-2.0" +authors = ["MagicStack Inc. "] +edition = "2021" + +[lints] +workspace = true + +[dependencies] +tracing.workspace = true + +rand = "0.8.5" +md5 = "0.7.0" +sha2 = "0.10.8" +roaring = "0.10.6" +constant_time_eq = "0.3" +base64 = "0.22" +unicode-normalization = "0.1.23" +thiserror = "2" +hmac = "0.12.1" +derive_more = { version = "1", features = ["debug"] } + +[dev-dependencies] +pretty_assertions = "1" +rstest = "0.23.0" +hex-literal = "0.4.1" + +[lib] diff --git a/rust/auth/README.md b/rust/auth/README.md new file mode 100644 index 00000000000..b904dc84186 --- /dev/null +++ b/rust/auth/README.md @@ -0,0 +1,8 @@ +# gel_auth + +Contains authentication routines for all supported auth methods for PostgreSQL and EdgeDB: + + - Plaintext + - MD5 + - SCRAM (`SCRAM-SHA-256` only) + diff --git a/rust/auth/src/handshake/mod.rs b/rust/auth/src/handshake/mod.rs new file mode 100644 index 00000000000..30b02633e58 --- /dev/null +++ b/rust/auth/src/handshake/mod.rs @@ -0,0 +1,3 @@ +mod server_auth; + +pub use server_auth::*; diff --git a/rust/pgrust/src/handshake/server_auth.rs b/rust/auth/src/handshake/server_auth.rs similarity index 98% rename from rust/pgrust/src/handshake/server_auth.rs rename to rust/auth/src/handshake/server_auth.rs index 7a6c0b9a01c..5f560237f38 100644 --- a/rust/pgrust/src/handshake/server_auth.rs +++ b/rust/auth/src/handshake/server_auth.rs @@ -1,4 +1,8 @@ -use crate::auth::{AuthType, CredentialData, SCRAMError, ServerTransaction, StoredHash, StoredKey}; +use crate::{ + md5::StoredHash, + scram::{SCRAMError, ServerTransaction, StoredKey}, + AuthType, CredentialData, +}; use tracing::error; #[derive(Debug)] diff --git a/rust/pgrust/src/auth/mod.rs b/rust/auth/src/lib.rs similarity index 82% rename from rust/pgrust/src/auth/mod.rs rename to rust/auth/src/lib.rs index 2a08588e803..b4b51a700d3 100644 --- a/rust/pgrust/src/auth/mod.rs +++ b/rust/auth/src/lib.rs @@ -1,15 +1,10 @@ -mod md5; -mod scram; -mod stringprep; +pub mod handshake; +pub mod md5; +pub mod scram; +pub mod stringprep; mod stringprep_table; -pub use md5::{md5_password, StoredHash}; use rand::Rng; -pub use scram::{ - generate_salted_password, ClientEnvironment, ClientTransaction, SCRAMError, ServerEnvironment, - ServerTransaction, Sha256Out, StoredKey, -}; -pub use stringprep::{sasl_normalize_password, sasl_normalize_password_bytes}; /// Specifies the type of authentication or indicates the authentication method used for a connection. #[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] @@ -49,9 +44,9 @@ pub enum CredentialData { /// A plain-text password. Plain(String), /// A stored MD5 hash + salt. - Md5(StoredHash), + Md5(md5::StoredHash), /// A stored SCRAM-SHA-256 key. - Scram(StoredKey), + Scram(scram::StoredKey), } impl CredentialData { @@ -60,10 +55,10 @@ impl CredentialData { AuthType::Deny => Self::Deny, AuthType::Trust => Self::Trust, AuthType::Plain => Self::Plain(password), - AuthType::Md5 => Self::Md5(StoredHash::generate(password.as_bytes(), &username)), + AuthType::Md5 => Self::Md5(md5::StoredHash::generate(password.as_bytes(), &username)), AuthType::ScramSha256 => { let salt: [u8; 32] = rand::thread_rng().gen(); - Self::Scram(StoredKey::generate(password.as_bytes(), &salt, 4096)) + Self::Scram(scram::StoredKey::generate(password.as_bytes(), &salt, 4096)) } } } diff --git a/rust/pgrust/src/auth/md5.rs b/rust/auth/src/md5.rs similarity index 97% rename from rust/pgrust/src/auth/md5.rs rename to rust/auth/src/md5.rs index c44ca531a02..2f5e03f2177 100644 --- a/rust/pgrust/src/auth/md5.rs +++ b/rust/auth/src/md5.rs @@ -14,10 +14,11 @@ /// # Example /// /// ``` +/// # use gel_auth::md5::*; /// let password = "secret"; /// let username = "user"; /// let salt = [0x01, 0x02, 0x03, 0x04]; -/// let hash = pgrust::auth::md5_password(password, username, &salt); +/// let hash = md5_password(password, username, &salt); /// assert_eq!(hash, "md5fccef98e4f1cf6cbe96b743fad4e8bd0"); /// ``` pub fn md5_password(password: &str, username: &str, salt: &[u8; 4]) -> String { diff --git a/rust/pgrust/src/auth/scram.rs b/rust/auth/src/scram.rs similarity index 99% rename from rust/pgrust/src/auth/scram.rs rename to rust/auth/src/scram.rs index 668c74cbe20..bcf055d9aae 100644 --- a/rust/pgrust/src/auth/scram.rs +++ b/rust/auth/src/scram.rs @@ -74,7 +74,7 @@ use sha2::{digest::FixedOutput, Digest, Sha256}; use std::borrow::Cow; use std::str::FromStr; -use super::sasl_normalize_password_bytes; +use crate::stringprep::sasl_normalize_password_bytes; const CHANNEL_BINDING_ENCODED: &str = "biws"; const MINIMUM_NONCE_LENGTH: usize = 16; @@ -776,6 +776,7 @@ fn generate_server_proof( mod tests { use super::*; use hex_literal::hex; + use pretty_assertions::{assert_eq, assert_ne}; use rstest::rstest; // Define a set of test parameters diff --git a/rust/pgrust/src/auth/stringprep.rs b/rust/auth/src/stringprep.rs similarity index 98% rename from rust/pgrust/src/auth/stringprep.rs rename to rust/auth/src/stringprep.rs index 67c57d9b227..89c3bc37d48 100644 --- a/rust/pgrust/src/auth/stringprep.rs +++ b/rust/auth/src/stringprep.rs @@ -8,8 +8,7 @@ use unicode_normalization::UnicodeNormalization; /// # Examples /// /// ``` -/// use pgrust::auth::sasl_normalize_password_bytes; -/// +/// # use gel_auth::stringprep::*; /// assert_eq!(sasl_normalize_password_bytes(b"password").as_ref(), b"password"); /// assert_eq!(sasl_normalize_password_bytes("passw\u{00A0}rd".as_bytes()).as_ref(), b"passw rd"); /// assert_eq!(sasl_normalize_password_bytes("pass\u{200B}word".as_bytes()).as_ref(), b"password"); @@ -36,8 +35,7 @@ pub fn sasl_normalize_password_bytes(s: &[u8]) -> Cow<[u8]> { /// # Examples /// /// ``` -/// use pgrust::auth::sasl_normalize_password; -/// +/// # use gel_auth::stringprep::*; /// assert_eq!(sasl_normalize_password("password").as_ref(), "password"); /// assert_eq!(sasl_normalize_password("passw\u{00A0}rd").as_ref(), "passw rd"); /// assert_eq!(sasl_normalize_password("pass\u{200B}word").as_ref(), "password"); diff --git a/rust/pgrust/src/auth/stringprep_table.rs b/rust/auth/src/stringprep_table.rs similarity index 100% rename from rust/pgrust/src/auth/stringprep_table.rs rename to rust/auth/src/stringprep_table.rs diff --git a/rust/pgrust/src/auth/stringprep_table_prep.py b/rust/auth/src/stringprep_table_prep.py similarity index 100% rename from rust/pgrust/src/auth/stringprep_table_prep.py rename to rust/auth/src/stringprep_table_prep.py diff --git a/rust/http/Cargo.toml b/rust/http/Cargo.toml index bba08eabf02..5cd39bc7eca 100644 --- a/rust/http/Cargo.toml +++ b/rust/http/Cargo.toml @@ -23,10 +23,6 @@ eventsource-stream = "0.2.3" futures = "0" -[dependencies.derive_more] -version = "1.0.0" -features = ["full"] - [dev-dependencies] tokio = { workspace = true, features = ["test-util"] } rstest = "0.23" diff --git a/rust/pgrust/Cargo.toml b/rust/pgrust/Cargo.toml index 535341ac1f6..b654139e944 100644 --- a/rust/pgrust/Cargo.toml +++ b/rust/pgrust/Cargo.toml @@ -13,6 +13,7 @@ python_extension = ["pyo3/extension-module", "pyo3/serde"] optimizer = [] [dependencies] +gel_auth.workspace = true pyo3.workspace = true tokio.workspace = true tracing.workspace = true @@ -22,19 +23,13 @@ thiserror = "1" openssl = { version = "0.10.66", features = ["v111"] } tokio-openssl = "0.6.4" paste = "1" -unicode-normalization = "0.1.23" -hmac = "0.12" base64 = "0.22" -sha2 = "0.10" -md5 = "0.7.0" rand = "0" hexdump = "0" url = "2" serde = "1" serde_derive = "1" percent-encoding = "2" -roaring = "0.10.6" -constant_time_eq = "0.3" uuid = "1" [dependencies.derive_more] @@ -47,14 +42,9 @@ scopeguard = "1" pretty_assertions = "1.2.0" test-log = { version = "0", features = ["trace"] } -anyhow = "1" rstest = "0" -statrs = "0" -lru = "0" -byteorder = "1.5" clap = "4" clap_derive = "4" -hex-literal = "0.4" tempfile = "3" socket2 = "0.5.7" libc = "0.2.158" diff --git a/rust/pgrust/src/connection/mod.rs b/rust/pgrust/src/connection/mod.rs index aaac05e07c8..e15be003092 100644 --- a/rust/pgrust/src/connection/mod.rs +++ b/rust/pgrust/src/connection/mod.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use crate::{ - auth, errors::{edgedb::EdbError, PgServerError}, protocol::ParseError, }; @@ -52,7 +51,7 @@ pub enum ConnectionError { /// Error related to SCRAM authentication. #[error("SCRAM: {0}")] - Scram(#[from] auth::SCRAMError), + Scram(#[from] gel_auth::scram::SCRAMError), /// I/O error encountered during connection operations. #[error("I/O error: {0}")] diff --git a/rust/pgrust/src/connection/raw_conn.rs b/rust/pgrust/src/connection/raw_conn.rs index a19c95216e5..24d16885e23 100644 --- a/rust/pgrust/src/connection/raw_conn.rs +++ b/rust/pgrust/src/connection/raw_conn.rs @@ -9,11 +9,9 @@ use crate::handshake::{ }, ConnectionSslRequirement, }; +use crate::protocol::postgres::{FrontendBuilder, InitialBuilder}; use crate::protocol::{postgres::data::SSLResponse, postgres::meta, StructBuffer}; -use crate::{ - auth::AuthType, - protocol::postgres::{FrontendBuilder, InitialBuilder}, -}; +use gel_auth::AuthType; use std::collections::HashMap; use std::pin::Pin; use std::task::{Context, Poll}; diff --git a/rust/pgrust/src/handshake/client_state_machine.rs b/rust/pgrust/src/handshake/client_state_machine.rs index 32524cac5de..7723f22b3d9 100644 --- a/rust/pgrust/src/handshake/client_state_machine.rs +++ b/rust/pgrust/src/handshake/client_state_machine.rs @@ -1,8 +1,5 @@ use super::ConnectionSslRequirement; use crate::{ - auth::{ - self, generate_salted_password, AuthType, ClientEnvironment, ClientTransaction, Sha256Out, - }, connection::{invalid_state, ConnectionError, Credentials, SslError}, errors::PgServerError, protocol::{ @@ -18,6 +15,10 @@ use crate::{ }, }; use base64::Engine; +use gel_auth::{ + scram::{generate_salted_password, ClientEnvironment, ClientTransaction, Sha256Out}, + AuthType, +}; use rand::Rng; use tracing::{error, trace, warn}; @@ -221,7 +222,7 @@ impl ConnectionState { let mut tx = ClientTransaction::new("".into()); let env = ClientEnvironmentImpl { credentials }; let Some(initial_message) = tx.process_message(&[], &env)? else { - return Err(auth::SCRAMError::ProtocolError.into()); + return Err(gel_auth::scram::SCRAMError::ProtocolError.into()); }; update.auth(AuthType::ScramSha256); update.send(builder::SASLInitialResponse { @@ -234,7 +235,7 @@ impl ConnectionState { (AuthenticationMD5Password as md5) => { *sent_auth = true; trace!("auth md5"); - let md5_hash = auth::md5_password(&credentials.password, &credentials.username, &md5.salt()); + let md5_hash = gel_auth::md5::md5_password(&credentials.password, &credentials.username, &md5.salt()); update.auth(AuthType::Md5); update.send(builder::PasswordMessage { password: &md5_hash, @@ -263,7 +264,7 @@ impl ConnectionState { match_message!(message, Backend { (AuthenticationSASLContinue as sasl) => { let Some(message) = tx.process_message(&sasl.data(), env)? else { - return Err(auth::SCRAMError::ProtocolError.into()); + return Err(gel_auth::scram::SCRAMError::ProtocolError.into()); }; update.send(builder::SASLResponse { response: &message, @@ -271,7 +272,7 @@ impl ConnectionState { }, (AuthenticationSASLFinal as sasl) => { let None = tx.process_message(&sasl.data(), env)? else { - return Err(auth::SCRAMError::ProtocolError.into()); + return Err(gel_auth::scram::SCRAMError::ProtocolError.into()); }; }, (AuthenticationOk) => { diff --git a/rust/pgrust/src/handshake/edgedb_server.rs b/rust/pgrust/src/handshake/edgedb_server.rs index d1c236c5d4a..faf3669a2e1 100644 --- a/rust/pgrust/src/handshake/edgedb_server.rs +++ b/rust/pgrust/src/handshake/edgedb_server.rs @@ -1,14 +1,15 @@ -use super::server_auth::{ServerAuth, ServerAuthError}; use crate::{ - auth::{AuthType, CredentialData}, connection::ConnectionError, errors::edgedb::EdbError, - handshake::server_auth::{ServerAuthDrive, ServerAuthResponse}, protocol::{ edgedb::{data::*, *}, match_message, ParseError, StructBuffer, }, }; +use gel_auth::{ + handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse}, + AuthType, CredentialData, +}; use std::str::Utf8Error; use tracing::{error, trace, warn}; @@ -32,7 +33,12 @@ pub enum ConnectionDrive<'a> { pub trait ConnectionStateSend { fn send(&mut self, message: EdgeDBBackendBuilder) -> Result<(), std::io::Error>; - fn auth(&mut self, user: String, database: String, branch: String) -> Result<(), std::io::Error>; + fn auth( + &mut self, + user: String, + database: String, + branch: String, + ) -> Result<(), std::io::Error>; fn params(&mut self) -> Result<(), std::io::Error>; } @@ -61,7 +67,12 @@ where self(ConnectionEvent::Send(message)) } - fn auth(&mut self, user: String, database: String, branch: String) -> Result<(), std::io::Error> { + fn auth( + &mut self, + user: String, + database: String, + branch: String, + ) -> Result<(), std::io::Error> { self(ConnectionEvent::Auth(user, database, branch)) } @@ -367,6 +378,7 @@ fn send_error( )) } +#[allow(unused)] enum ErrorSeverity { Error = 0x78, Fatal = 0xc8, diff --git a/rust/pgrust/src/handshake/mod.rs b/rust/pgrust/src/handshake/mod.rs index 70b3bea9185..a253e3bbde6 100644 --- a/rust/pgrust/src/handshake/mod.rs +++ b/rust/pgrust/src/handshake/mod.rs @@ -11,7 +11,6 @@ pub enum ConnectionSslRequirement { mod client_state_machine; pub mod edgedb_server; -mod server_auth; mod server_state_machine; pub mod client { @@ -30,11 +29,11 @@ mod tests { ConnectionSslRequirement, }; use crate::{ - auth::{AuthType, CredentialData}, connection::Credentials, errors::{PgError, PgServerError}, protocol::postgres::{data::*, *}, }; + use gel_auth::{AuthType, CredentialData}; use rstest::rstest; use std::collections::VecDeque; diff --git a/rust/pgrust/src/handshake/server_state_machine.rs b/rust/pgrust/src/handshake/server_state_machine.rs index f8f3c2a506b..8c62c00b610 100644 --- a/rust/pgrust/src/handshake/server_state_machine.rs +++ b/rust/pgrust/src/handshake/server_state_machine.rs @@ -1,21 +1,20 @@ -use super::{ - server_auth::{ServerAuth, ServerAuthError}, - ConnectionSslRequirement, -}; +use super::ConnectionSslRequirement; use crate::{ - auth::{AuthType, CredentialData}, connection::ConnectionError, errors::{ PgError, PgErrorConnectionException, PgErrorFeatureNotSupported, PgErrorInvalidAuthorizationSpecification, PgServerError, PgServerErrorField, }, - handshake::server_auth::{ServerAuthDrive, ServerAuthResponse}, protocol::{ match_message, postgres::{data::*, *}, ParseError, StructBuffer, }, }; +use gel_auth::{ + handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse}, + AuthType, CredentialData, +}; use std::str::Utf8Error; use tracing::{error, trace, warn}; diff --git a/rust/pgrust/src/lib.rs b/rust/pgrust/src/lib.rs index 1948eeedbd1..dcf7bf89569 100644 --- a/rust/pgrust/src/lib.rs +++ b/rust/pgrust/src/lib.rs @@ -1,4 +1,3 @@ -pub mod auth; pub mod connection; pub mod errors; pub mod handshake; diff --git a/rust/pgrust/src/python.rs b/rust/pgrust/src/python.rs index a9aef3509d5..0914e181367 100644 --- a/rust/pgrust/src/python.rs +++ b/rust/pgrust/src/python.rs @@ -474,7 +474,7 @@ impl ConnectionStateUpdate for PyConnectionStateUpdate { }); } - fn auth(&mut self, auth: crate::auth::AuthType) { + fn auth(&mut self, auth: gel_auth::AuthType) { Python::with_gil(|py| { if let Err(e) = self.py_update.call_method1(py, "auth", (auth as u8,)) { eprintln!("Error in auth: {:?}", e); diff --git a/rust/pgrust/tests/real_postgres.rs b/rust/pgrust/tests/real_postgres.rs index a263ee3fc67..d9238e6f5d9 100644 --- a/rust/pgrust/tests/real_postgres.rs +++ b/rust/pgrust/tests/real_postgres.rs @@ -1,6 +1,6 @@ // Constants +use gel_auth::AuthType; use openssl::ssl::{Ssl, SslContext, SslMethod}; -use pgrust::auth::AuthType; use pgrust::connection::dsn::{Host, HostType}; use pgrust::connection::{connect_raw_ssl, ConnectionError, Credentials, ResolvedTarget}; use pgrust::errors::PgServerError; From cb22e82e6105513e800b5c88b9f9a3b5683a0629 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Tue, 26 Nov 2024 12:40:02 -0700 Subject: [PATCH 6/6] Cleanup --- edb/server/_rust_native/Cargo.toml | 2 +- rust/pgrust/Cargo.toml | 2 +- rust/pgrust/src/protocol/gen.rs | 4 ---- rust/pyo3_util/Cargo.toml | 2 +- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/edb/server/_rust_native/Cargo.toml b/edb/server/_rust_native/Cargo.toml index 08b6927a797..63c1c0a4b75 100644 --- a/edb/server/_rust_native/Cargo.toml +++ b/edb/server/_rust_native/Cargo.toml @@ -5,7 +5,7 @@ license = "MIT/Apache-2.0" authors = ["MagicStack Inc. "] edition = "2021" -[lint] +[lints] workspace = true [features] diff --git a/rust/pgrust/Cargo.toml b/rust/pgrust/Cargo.toml index b654139e944..f45984a4169 100644 --- a/rust/pgrust/Cargo.toml +++ b/rust/pgrust/Cargo.toml @@ -5,7 +5,7 @@ license = "MIT/Apache-2.0" authors = ["MagicStack Inc. "] edition = "2021" -[lint] +[lints] workspace = true [features] diff --git a/rust/pgrust/src/protocol/gen.rs b/rust/pgrust/src/protocol/gen.rs index c44f87956c8..19dca193326 100644 --- a/rust/pgrust/src/protocol/gen.rs +++ b/rust/pgrust/src/protocol/gen.rs @@ -689,7 +689,6 @@ mod tests { } mod mixed { - use crate::protocol::meta::ZTString; protocol!(struct Mixed { a: u8 = 1, s: ZTString, @@ -697,7 +696,6 @@ mod tests { } mod docs { - use crate::protocol::meta::ZTString; protocol!( /// Docs struct Docs { @@ -710,7 +708,6 @@ mod tests { } mod length { - use crate::protocol::meta::Length; protocol!( struct WithLength { a: u8, @@ -729,7 +726,6 @@ mod tests { } mod string { - use crate::protocol::meta::LString; protocol!( struct HasLString { s: LString, diff --git a/rust/pyo3_util/Cargo.toml b/rust/pyo3_util/Cargo.toml index ea097ebbeb6..05b59f22094 100644 --- a/rust/pyo3_util/Cargo.toml +++ b/rust/pyo3_util/Cargo.toml @@ -5,7 +5,7 @@ license = "MIT/Apache-2.0" authors = ["MagicStack Inc. "] edition = "2021" -[lint] +[lints] workspace = true [dependencies]