From e7f1e27f54d9ca1f5f088092918a55d258717a31 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Sun, 29 Sep 2024 19:52:15 +0800 Subject: [PATCH] add ssl negotiation for pg 17 --- postgres/src/config.rs | 37 +++++++++++++++++++++++++++++++++- postgres/src/driver.rs | 5 +++-- postgres/src/driver/connect.rs | 25 ++++++++++++++--------- postgres/src/session.rs | 22 ++++++++++---------- 4 files changed, 65 insertions(+), 24 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 31f42866..955a775f 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -9,17 +9,29 @@ use std::{ use super::{error::Error, session::TargetSessionAttrs}; -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)] #[non_exhaustive] pub enum SslMode { /// Do not use TLS. Disable, /// Attempt to connect with TLS but allow sessions without. + #[default] Prefer, /// Require the use of TLS. Require, } +/// TLS negotiation configuration +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)] +#[non_exhaustive] +pub enum SslNegotiation { + /// Use PostgreSQL SslRequest for Ssl negotiation + #[default] + Postgres, + /// Start Ssl handshake without negotiation, only works for PostgreSQL 17+ + Direct, +} + /// A host specification. #[derive(Clone, Debug, Eq, PartialEq)] pub enum Host { @@ -38,6 +50,7 @@ pub struct Config { pub(crate) options: Option>, pub(crate) application_name: Option>, pub(crate) ssl_mode: SslMode, + pub(crate) ssl_negotiation: SslNegotiation, pub(crate) host: Vec, pub(crate) port: Vec, target_session_attrs: TargetSessionAttrs, @@ -60,6 +73,7 @@ impl Config { options: None, application_name: None, ssl_mode: SslMode::Prefer, + ssl_negotiation: SslNegotiation::Postgres, host: Vec::new(), port: Vec::new(), target_session_attrs: TargetSessionAttrs::Any, @@ -147,6 +161,19 @@ impl Config { self.ssl_mode } + /// Sets the SSL negotiation method. + /// + /// Defaults to `postgres`. + pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config { + self.ssl_negotiation = ssl_negotiation; + self + } + + /// Gets the SSL negotiation method. + pub fn get_ssl_negotiation(&self) -> SslNegotiation { + self.ssl_negotiation + } + pub fn host(&mut self, host: &str) -> &mut Config { if host.starts_with('/') { return self.host_path(host); @@ -306,6 +333,14 @@ impl Config { }; self.ssl_mode(mode); } + "sslnegotiation" => { + let mode = match value { + "postgres" => SslNegotiation::Postgres, + "direct" => SslNegotiation::Direct, + _ => return Err(Error::todo()), + }; + self.ssl_negotiation(mode); + } "host" => { for host in value.split(',') { self.host(host); diff --git a/postgres/src/driver.rs b/postgres/src/driver.rs index f2457a0d..f47dcd28 100644 --- a/postgres/src/driver.rs +++ b/postgres/src/driver.rs @@ -30,7 +30,7 @@ use xitca_io::{ use super::{ client::Client, - config::{Config, SslMode}, + config::{Config, SslMode, SslNegotiation}, error::{unexpected_eof_err, ConfigError, Error}, iter::AsyncLendingIterator, session::{ConnectInfo, Session}, @@ -90,7 +90,7 @@ where Ok((tx, session, drv)) } -async fn should_connect_tls(io: &mut Io, ssl_mode: SslMode) -> Result +async fn should_connect_tls(io: &mut Io, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Result where Io: AsyncIo, { @@ -127,6 +127,7 @@ where match ssl_mode { SslMode::Disable => Ok(false), + _ if matches!(ssl_negotiation, SslNegotiation::Direct) => Ok(true), mode => match (query_tls_availability(io).await?, mode) { (false, SslMode::Require) => Err(Error::todo()), (bool, _) => Ok(bool), diff --git a/postgres/src/driver/connect.rs b/postgres/src/driver/connect.rs index 6622008a..c8fd6404 100644 --- a/postgres/src/driver/connect.rs +++ b/postgres/src/driver/connect.rs @@ -36,15 +36,16 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver } let ssl_mode = cfg.get_ssl_mode(); + let ssl_negotiation = cfg.get_ssl_negotiation(); match host { Host::Tcp(host) => { let (mut io, addr) = connect_tcp(&host, cfg.get_ports()).await?; - if should_connect_tls(&mut io, ssl_mode).await? { + if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? { #[cfg(feature = "tls")] { let io = super::tls::connect_tls(io, &host, cfg).await?; - let info = ConnectInfo::new(Addr::Tcp(host, addr), ssl_mode); + let info = ConnectInfo::new(Addr::Tcp(host, addr), ssl_mode, ssl_negotiation); prepare_driver(info, io, cfg) .await .map(|(tx, session, drv)| (tx, session, Driver::Tls(drv))) @@ -54,7 +55,7 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver Err(crate::error::FeatureError::Tls.into()) } } else { - let info = ConnectInfo::new(Addr::Tcp(host, addr), ssl_mode); + let info = ConnectInfo::new(Addr::Tcp(host, addr), ssl_mode, ssl_negotiation); prepare_driver(info, io, cfg) .await .map(|(tx, session, drv)| (tx, session, Driver::Tcp(drv))) @@ -66,11 +67,11 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver Host::Unix(host) => { let mut io = xitca_io::net::UnixStream::connect(&host).await?; let host_str: Box = host.to_string_lossy().into(); - if should_connect_tls(&mut io, ssl_mode).await? { + if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? { #[cfg(feature = "tls")] { let io = super::tls::connect_tls(io, host_str.as_ref(), cfg).await?; - let info = ConnectInfo::new(Addr::Unix(host_str, host), ssl_mode); + let info = ConnectInfo::new(Addr::Unix(host_str, host), ssl_mode, ssl_negotiation); prepare_driver(info, io, cfg) .await .map(|(tx, session, drv)| (tx, session, Driver::UnixTls(drv))) @@ -80,7 +81,7 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver Err(crate::error::FeatureError::Tls.into()) } } else { - let info = ConnectInfo::new(Addr::Unix(host_str, host), ssl_mode); + let info = ConnectInfo::new(Addr::Unix(host_str, host), ssl_mode, ssl_negotiation); prepare_driver(info, io, cfg) .await .map(|(tx, session, drv)| (tx, session, Driver::Unix(drv))) @@ -91,7 +92,7 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver #[cfg(feature = "quic")] Host::Quic(host) => { let (io, addr) = super::quic::connect_quic(&host, cfg.get_ports()).await?; - let info = ConnectInfo::new(Addr::Quic(host, addr), ssl_mode); + let info = ConnectInfo::new(Addr::Quic(host, addr), ssl_mode, ssl_negotiation); prepare_driver(info, io, cfg) .await .map(|(tx, session, drv)| (tx, session, Driver::Quic(drv))) @@ -102,13 +103,17 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver #[cold] #[inline(never)] pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver), Error> { - let ConnectInfo { addr, ssl_mode } = info; + let ConnectInfo { + addr, + ssl_mode, + ssl_negotiation, + } = info; match addr { Addr::Tcp(_host, addr) => { let mut io = TcpStream::connect(addr).await?; let _ = io.set_nodelay(true); - if should_connect_tls(&mut io, ssl_mode).await? { + if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? { #[cfg(feature = "tls")] { let io = super::tls::connect_tls(io, &_host, &mut Config::default()).await?; @@ -127,7 +132,7 @@ pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver) #[cfg(unix)] Addr::Unix(_host, path) => { let mut io = xitca_io::net::UnixStream::connect(path).await?; - if should_connect_tls(&mut io, ssl_mode).await? { + if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? { #[cfg(feature = "tls")] { let io = super::tls::connect_tls(io, &_host, &mut Config::default()).await?; diff --git a/postgres/src/session.rs b/postgres/src/session.rs index c0122fbc..0bde7e60 100644 --- a/postgres/src/session.rs +++ b/postgres/src/session.rs @@ -10,7 +10,7 @@ use postgres_protocol::{ use xitca_io::{bytes::BytesMut, io::AsyncIo}; use super::{ - config::{Config, SslMode}, + config::{Config, SslMode, SslNegotiation}, driver::generic::GenericDriver, error::{AuthenticationError, Error}, }; @@ -35,25 +35,24 @@ pub struct Session { pub(crate) info: ConnectInfo, } -#[derive(Clone)] +#[derive(Clone, Default)] pub(crate) struct ConnectInfo { pub(crate) addr: Addr, pub(crate) ssl_mode: SslMode, -} - -impl Default for ConnectInfo { - fn default() -> Self { - Self::new(Addr::None, SslMode::Disable) - } + pub(crate) ssl_negotiation: SslNegotiation, } impl ConnectInfo { - pub(crate) fn new(addr: Addr, ssl_mode: SslMode) -> Self { - Self { addr, ssl_mode } + pub(crate) fn new(addr: Addr, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Self { + Self { + addr, + ssl_mode, + ssl_negotiation, + } } } -#[derive(Clone)] +#[derive(Clone, Default)] pub(crate) enum Addr { Tcp(Box, SocketAddr), #[cfg(unix)] @@ -61,6 +60,7 @@ pub(crate) enum Addr { #[cfg(feature = "quic")] Quic(Box, SocketAddr), // case for where io is supplied by user and no connectivity can be done from this crate + #[default] None, }