Skip to content

Commit

Permalink
add ssl negotiation for pg 17
Browse files Browse the repository at this point in the history
  • Loading branch information
fakeshadow committed Sep 29, 2024
1 parent 11a72b8 commit e7f1e27
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 24 deletions.
37 changes: 36 additions & 1 deletion postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,29 @@ use std::{

use super::{error::Error, session::TargetSessionAttrs};

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslMode {
/// Do not use TLS.
Disable,
/// Attempt to connect with TLS but allow sessions without.
#[default]
Prefer,
/// Require the use of TLS.
Require,
}

/// TLS negotiation configuration
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslNegotiation {
/// Use PostgreSQL SslRequest for Ssl negotiation
#[default]
Postgres,
/// Start Ssl handshake without negotiation, only works for PostgreSQL 17+
Direct,
}

/// A host specification.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Host {
Expand All @@ -38,6 +50,7 @@ pub struct Config {
pub(crate) options: Option<Box<str>>,
pub(crate) application_name: Option<Box<str>>,
pub(crate) ssl_mode: SslMode,
pub(crate) ssl_negotiation: SslNegotiation,
pub(crate) host: Vec<Host>,
pub(crate) port: Vec<u16>,
target_session_attrs: TargetSessionAttrs,
Expand All @@ -60,6 +73,7 @@ impl Config {
options: None,
application_name: None,
ssl_mode: SslMode::Prefer,
ssl_negotiation: SslNegotiation::Postgres,
host: Vec::new(),
port: Vec::new(),
target_session_attrs: TargetSessionAttrs::Any,
Expand Down Expand Up @@ -147,6 +161,19 @@ impl Config {
self.ssl_mode
}

/// Sets the SSL negotiation method.
///
/// Defaults to `postgres`.
pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
self.ssl_negotiation = ssl_negotiation;
self
}

/// Gets the SSL negotiation method.
pub fn get_ssl_negotiation(&self) -> SslNegotiation {
self.ssl_negotiation
}

pub fn host(&mut self, host: &str) -> &mut Config {
if host.starts_with('/') {
return self.host_path(host);
Expand Down Expand Up @@ -306,6 +333,14 @@ impl Config {
};
self.ssl_mode(mode);
}
"sslnegotiation" => {
let mode = match value {
"postgres" => SslNegotiation::Postgres,
"direct" => SslNegotiation::Direct,
_ => return Err(Error::todo()),
};
self.ssl_negotiation(mode);
}
"host" => {
for host in value.split(',') {
self.host(host);
Expand Down
5 changes: 3 additions & 2 deletions postgres/src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use xitca_io::{

use super::{
client::Client,
config::{Config, SslMode},
config::{Config, SslMode, SslNegotiation},
error::{unexpected_eof_err, ConfigError, Error},
iter::AsyncLendingIterator,
session::{ConnectInfo, Session},
Expand Down Expand Up @@ -90,7 +90,7 @@ where
Ok((tx, session, drv))
}

async fn should_connect_tls<Io>(io: &mut Io, ssl_mode: SslMode) -> Result<bool, Error>
async fn should_connect_tls<Io>(io: &mut Io, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Result<bool, Error>
where
Io: AsyncIo,
{
Expand Down Expand Up @@ -127,6 +127,7 @@ where

match ssl_mode {
SslMode::Disable => Ok(false),
_ if matches!(ssl_negotiation, SslNegotiation::Direct) => Ok(true),
mode => match (query_tls_availability(io).await?, mode) {
(false, SslMode::Require) => Err(Error::todo()),
(bool, _) => Ok(bool),
Expand Down
25 changes: 15 additions & 10 deletions postgres/src/driver/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver
}

let ssl_mode = cfg.get_ssl_mode();
let ssl_negotiation = cfg.get_ssl_negotiation();

match host {
Host::Tcp(host) => {
let (mut io, addr) = connect_tcp(&host, cfg.get_ports()).await?;
if should_connect_tls(&mut io, ssl_mode).await? {
if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? {
#[cfg(feature = "tls")]
{
let io = super::tls::connect_tls(io, &host, cfg).await?;
let info = ConnectInfo::new(Addr::Tcp(host, addr), ssl_mode);
let info = ConnectInfo::new(Addr::Tcp(host, addr), ssl_mode, ssl_negotiation);
prepare_driver(info, io, cfg)
.await
.map(|(tx, session, drv)| (tx, session, Driver::Tls(drv)))
Expand All @@ -54,7 +55,7 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver
Err(crate::error::FeatureError::Tls.into())
}
} else {
let info = ConnectInfo::new(Addr::Tcp(host, addr), ssl_mode);
let info = ConnectInfo::new(Addr::Tcp(host, addr), ssl_mode, ssl_negotiation);
prepare_driver(info, io, cfg)
.await
.map(|(tx, session, drv)| (tx, session, Driver::Tcp(drv)))
Expand All @@ -66,11 +67,11 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver
Host::Unix(host) => {
let mut io = xitca_io::net::UnixStream::connect(&host).await?;
let host_str: Box<str> = host.to_string_lossy().into();
if should_connect_tls(&mut io, ssl_mode).await? {
if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? {
#[cfg(feature = "tls")]
{
let io = super::tls::connect_tls(io, host_str.as_ref(), cfg).await?;
let info = ConnectInfo::new(Addr::Unix(host_str, host), ssl_mode);
let info = ConnectInfo::new(Addr::Unix(host_str, host), ssl_mode, ssl_negotiation);
prepare_driver(info, io, cfg)
.await
.map(|(tx, session, drv)| (tx, session, Driver::UnixTls(drv)))
Expand All @@ -80,7 +81,7 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver
Err(crate::error::FeatureError::Tls.into())
}
} else {
let info = ConnectInfo::new(Addr::Unix(host_str, host), ssl_mode);
let info = ConnectInfo::new(Addr::Unix(host_str, host), ssl_mode, ssl_negotiation);
prepare_driver(info, io, cfg)
.await
.map(|(tx, session, drv)| (tx, session, Driver::Unix(drv)))
Expand All @@ -91,7 +92,7 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver
#[cfg(feature = "quic")]
Host::Quic(host) => {
let (io, addr) = super::quic::connect_quic(&host, cfg.get_ports()).await?;
let info = ConnectInfo::new(Addr::Quic(host, addr), ssl_mode);
let info = ConnectInfo::new(Addr::Quic(host, addr), ssl_mode, ssl_negotiation);
prepare_driver(info, io, cfg)
.await
.map(|(tx, session, drv)| (tx, session, Driver::Quic(drv)))
Expand All @@ -102,13 +103,17 @@ pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(Driver
#[cold]
#[inline(never)]
pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver), Error> {
let ConnectInfo { addr, ssl_mode } = info;
let ConnectInfo {
addr,
ssl_mode,
ssl_negotiation,
} = info;
match addr {
Addr::Tcp(_host, addr) => {
let mut io = TcpStream::connect(addr).await?;
let _ = io.set_nodelay(true);

if should_connect_tls(&mut io, ssl_mode).await? {
if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? {
#[cfg(feature = "tls")]
{
let io = super::tls::connect_tls(io, &_host, &mut Config::default()).await?;
Expand All @@ -127,7 +132,7 @@ pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver)
#[cfg(unix)]
Addr::Unix(_host, path) => {
let mut io = xitca_io::net::UnixStream::connect(path).await?;
if should_connect_tls(&mut io, ssl_mode).await? {
if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? {
#[cfg(feature = "tls")]
{
let io = super::tls::connect_tls(io, &_host, &mut Config::default()).await?;
Expand Down
22 changes: 11 additions & 11 deletions postgres/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use postgres_protocol::{
use xitca_io::{bytes::BytesMut, io::AsyncIo};

use super::{
config::{Config, SslMode},
config::{Config, SslMode, SslNegotiation},
driver::generic::GenericDriver,
error::{AuthenticationError, Error},
};
Expand All @@ -35,32 +35,32 @@ pub struct Session {
pub(crate) info: ConnectInfo,
}

#[derive(Clone)]
#[derive(Clone, Default)]
pub(crate) struct ConnectInfo {
pub(crate) addr: Addr,
pub(crate) ssl_mode: SslMode,
}

impl Default for ConnectInfo {
fn default() -> Self {
Self::new(Addr::None, SslMode::Disable)
}
pub(crate) ssl_negotiation: SslNegotiation,
}

impl ConnectInfo {
pub(crate) fn new(addr: Addr, ssl_mode: SslMode) -> Self {
Self { addr, ssl_mode }
pub(crate) fn new(addr: Addr, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Self {
Self {
addr,
ssl_mode,
ssl_negotiation,
}
}
}

#[derive(Clone)]
#[derive(Clone, Default)]
pub(crate) enum Addr {
Tcp(Box<str>, SocketAddr),
#[cfg(unix)]
Unix(Box<str>, std::path::PathBuf),
#[cfg(feature = "quic")]
Quic(Box<str>, SocketAddr),
// case for where io is supplied by user and no connectivity can be done from this crate
#[default]
None,
}

Expand Down

0 comments on commit e7f1e27

Please sign in to comment.