diff --git a/rust/captive_postgres/src/lib.rs b/rust/captive_postgres/src/lib.rs index 03be2d2ff53..885e8c870aa 100644 --- a/rust/captive_postgres/src/lib.rs +++ b/rust/captive_postgres/src/lib.rs @@ -19,6 +19,202 @@ pub const DEFAULT_USERNAME: &str = "username"; pub const DEFAULT_PASSWORD: &str = "password"; pub const DEFAULT_DATABASE: &str = "postgres"; +use std::collections::HashMap; + +#[derive(Debug, Clone, Default)] +pub enum PostgresBinPath { + #[default] + Path, + Specified(PathBuf), +} + +#[derive(Debug)] +pub struct PostgresBuilder { + auth: AuthType, + bin_path: PostgresBinPath, + data_dir: Option, + server_options: HashMap, + ssl_cert_and_key: Option<(PathBuf, PathBuf)>, + unix_enabled: bool, + debug_level: Option, +} + +impl Default for PostgresBuilder { + fn default() -> Self { + Self { + auth: AuthType::Trust, + bin_path: PostgresBinPath::default(), + data_dir: None, + server_options: HashMap::new(), + ssl_cert_and_key: None, + unix_enabled: false, + debug_level: None, + } + } +} + +impl PostgresBuilder { + pub fn new() -> Self { + Self::default() + } + + /// Attempt to configure the builder to use the default postgres binaries. + /// Returns an error if the binaries are not found. + pub fn with_automatic_bin_path(mut self) -> std::io::Result { + let bindir = postgres_bin_dir()?; + self.bin_path = PostgresBinPath::Specified(bindir); + Ok(self) + } + + /// Configures the builder with a quick networking mode. + pub fn with_automatic_mode(mut self, mode: Mode) -> Self { + match mode { + Mode::Tcp => { + // No special configuration needed for TCP mode + } + Mode::TcpSsl => { + let certs_dir = test_data_dir().join("certs"); + let cert = certs_dir.join("server.cert.pem"); + let key = certs_dir.join("server.key.pem"); + self.ssl_cert_and_key = Some((cert, key)); + } + Mode::Unix => { + self.unix_enabled = true; + } + } + self + } + + pub fn auth(mut self, auth: AuthType) -> Self { + self.auth = auth; + self + } + + pub fn bin_path(mut self, bin_path: impl AsRef) -> Self { + self.bin_path = PostgresBinPath::Specified(bin_path.as_ref().to_path_buf()); + self + } + + pub fn data_dir(mut self, data_dir: PathBuf) -> Self { + self.data_dir = Some(data_dir); + self + } + + pub fn debug_level(mut self, debug_level: u8) -> Self { + self.debug_level = Some(debug_level); + self + } + + pub fn server_option(mut self, key: impl AsRef, value: impl AsRef) -> Self { + self.server_options + .insert(key.as_ref().to_string(), value.as_ref().to_string()); + self + } + + pub fn server_options( + mut self, + server_options: impl IntoIterator, impl AsRef)>, + ) -> Self { + for (key, value) in server_options { + self.server_options + .insert(key.as_ref().to_string(), value.as_ref().to_string()); + } + self + } + + pub fn enable_ssl(mut self, cert_path: PathBuf, key_path: PathBuf) -> Self { + self.ssl_cert_and_key = Some((cert_path, key_path)); + self + } + + pub fn enable_unix(mut self) -> Self { + self.unix_enabled = true; + self + } + + pub fn build(self) -> std::io::Result { + let initdb = match &self.bin_path { + PostgresBinPath::Path => "initdb".into(), + PostgresBinPath::Specified(path) => path.join("initdb"), + }; + let postgres = match &self.bin_path { + PostgresBinPath::Path => "postgres".into(), + PostgresBinPath::Specified(path) => path.join("postgres"), + }; + + if !initdb.exists() { + return Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + format!("initdb executable not found at {}", initdb.display()), + )); + } + if !postgres.exists() { + return Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + format!("postgres executable not found at {}", postgres.display()), + )); + } + + let temp_dir = TempDir::new()?; + let port = EphemeralPort::allocate()?; + let data_dir = self + .data_dir + .unwrap_or_else(|| temp_dir.path().join("data")); + + init_postgres(&initdb, &data_dir, self.auth)?; + let port = port.take(); + + let ssl_config = self.ssl_cert_and_key; + + let (socket_address, socket_path) = if self.unix_enabled { + ( + ListenAddress::Unix(get_unix_socket_path(&data_dir, port)), + Some(&data_dir), + ) + } else { + ( + ListenAddress::Tcp(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port)), + None, + ) + }; + + let tcp_address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port); + + let mut command = Command::new(postgres); + command + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .arg("-D") + .arg(&data_dir) + .arg("-h") + .arg(Ipv4Addr::LOCALHOST.to_string()) + .arg("-F") + .arg("-p") + .arg(port.to_string()); + + if let Some(socket_path) = &socket_path { + command.arg("-k").arg(socket_path); + } + + for (key, value) in self.server_options { + command.arg("-c").arg(format!("{}={}", key, value)); + } + + if let Some(debug_level) = self.debug_level { + command.arg("-d").arg(debug_level.to_string()); + } + + let child = run_postgres(command, &data_dir, socket_path, ssl_config, port)?; + + Ok(PostgresProcess { + child, + socket_address, + tcp_address, + temp_dir, + }) + } +} + #[derive(Debug, Clone)] pub enum ListenAddress { Tcp(SocketAddr), @@ -138,6 +334,7 @@ fn init_postgres(initdb: &Path, data_dir: &Path, auth: AuthType) -> std::io::Res .arg("-U") .arg(DEFAULT_USERNAME); + eprintln!("initdb command: {:?}", command); let output = command.output()?; let status = output.status; @@ -158,28 +355,13 @@ fn init_postgres(initdb: &Path, data_dir: &Path, auth: AuthType) -> std::io::Res } fn run_postgres( - postgres_bin: &Path, + mut command: Command, data_dir: &Path, - socket_path: &Path, + socket_path: Option>, ssl: Option<(PathBuf, PathBuf)>, port: u16, ) -> std::io::Result { - let mut command = Command::new(postgres_bin); - command - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .arg("-D") - .arg(data_dir) - .arg("-k") - .arg(socket_path) - .arg("-h") - .arg(Ipv4Addr::LOCALHOST.to_string()) - .arg("-F") - // Useful for debugging - // .arg("-d") - // .arg("5") - .arg("-p") - .arg(port.to_string()); + let socket_path = socket_path.map(|path| path.as_ref().to_owned()); if let Some((cert_path, key_path)) = ssl { let postgres_cert_path = data_dir.join("server.crt"); @@ -223,11 +405,13 @@ fn run_postgres( let mut tcp_socket: Option = None; let mut unix_socket: Option = None; - let unix_socket_path = get_unix_socket_path(socket_path, port); + let unix_socket_path = socket_path.map(|path| get_unix_socket_path(path, port)); let tcp_socket_addr = std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, port)); + let mut db_ready = false; + let mut network_ready = false; - while start_time.elapsed() < STARTUP_TIMEOUT_DURATION { + while start_time.elapsed() < STARTUP_TIMEOUT_DURATION && !network_ready { std::thread::sleep(HOT_LOOP_INTERVAL); match child.try_wait() { Ok(Some(status)) => { @@ -245,19 +429,17 @@ fn run_postgres( } else { continue; } - if unix_socket.is_none() { - unix_socket = std::os::unix::net::UnixStream::connect(&unix_socket_path).ok(); + if let Some(unix_socket_path) = &unix_socket_path { + if unix_socket.is_none() { + unix_socket = std::os::unix::net::UnixStream::connect(unix_socket_path).ok(); + } } if tcp_socket.is_none() { tcp_socket = std::net::TcpStream::connect(tcp_socket_addr).ok(); } - if unix_socket.is_some() && tcp_socket.is_some() { - break; - } - } - if unix_socket.is_some() && tcp_socket.is_some() { - return Ok(child); + network_ready = + (unix_socket_path.is_none() || unix_socket.is_some()) && tcp_socket.is_some(); } // Print status for TCP/unix sockets @@ -276,6 +458,10 @@ fn run_postgres( eprintln!("Unix socket at {unix_socket_path:?} connection failed"); } + if network_ready { + return Ok(child); + } + Err(std::io::Error::new( std::io::ErrorKind::TimedOut, "PostgreSQL failed to start within 30 seconds", @@ -302,8 +488,8 @@ fn postgres_bin_dir() -> std::io::Result { } } -fn get_unix_socket_path(socket_path: &Path, port: u16) -> PathBuf { - socket_path.join(format!(".s.PGSQL.{}", port)) +fn get_unix_socket_path(socket_path: impl AsRef, port: u16) -> PathBuf { + socket_path.as_ref().join(format!(".s.PGSQL.{}", port)) } #[derive(Debug, Clone, Copy)] @@ -323,6 +509,7 @@ pub fn create_ssl_client() -> Result> { pub struct PostgresProcess { child: std::process::Child, pub socket_address: ListenAddress, + pub tcp_address: SocketAddr, #[allow(unused)] temp_dir: TempDir, } @@ -334,51 +521,54 @@ impl Drop for PostgresProcess { } /// Creates and runs a new Postgres server process in a temporary directory. -pub fn setup_postgres( - auth: AuthType, - mode: Mode, -) -> Result, Box> { - let Ok(bindir) = postgres_bin_dir() else { - println!("Skipping test: postgres bin dir not found"); - return Ok(None); - }; - - let initdb = bindir.join("initdb"); - let postgres = bindir.join("postgres"); +pub fn setup_postgres(auth: AuthType, mode: Mode) -> std::io::Result> { + let builder: PostgresBuilder = PostgresBuilder::new(); - if !initdb.exists() || !postgres.exists() { - println!("Skipping test: initdb or postgres not found"); + let Ok(mut builder) = builder.with_automatic_bin_path() else { + eprintln!("Skipping test: postgres bin dir not found"); return Ok(None); - } - - let temp_dir = TempDir::new()?; - let port = EphemeralPort::allocate()?; - let data_dir = temp_dir.path().join("data"); - - init_postgres(&initdb, &data_dir, auth)?; - let ssl_key = match mode { - Mode::TcpSsl => { - let certs_dir = test_data_dir().join("certs"); - let cert = certs_dir.join("server.cert.pem"); - let key = certs_dir.join("server.key.pem"); - Some((cert, key)) - } - _ => None, }; - let port = port.take(); - let child = run_postgres(&postgres, &data_dir, &data_dir, ssl_key, port)?; + builder = builder.auth(auth).with_automatic_mode(mode); - let socket_address = match mode { - Mode::Unix => ListenAddress::Unix(get_unix_socket_path(&data_dir, port)), - Mode::Tcp | Mode::TcpSsl => { - ListenAddress::Tcp(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port)) - } - }; + let process = builder.build()?; + Ok(Some(process)) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_builder_defaults() { + let builder = PostgresBuilder::new(); + assert!(matches!(builder.auth, AuthType::Trust)); + assert!(matches!(builder.bin_path, PostgresBinPath::Path)); + assert!(builder.data_dir.is_none()); + assert_eq!(builder.server_options.len(), 0); + } - Ok(Some(PostgresProcess { - child, - socket_address, - temp_dir, - })) + #[test] + fn test_builder_customization() { + let mut options = HashMap::new(); + options.insert("max_connections", "100"); + + let data_dir = PathBuf::from("/tmp/pg_data"); + let bin_path = PathBuf::from("/usr/local/pgsql/bin"); + + let builder = PostgresBuilder::new() + .auth(AuthType::Md5) + .bin_path(bin_path) + .data_dir(data_dir.clone()) + .server_options(options); + + assert!(matches!(builder.auth, AuthType::Md5)); + assert!(matches!(builder.bin_path, PostgresBinPath::Specified(_))); + assert_eq!(builder.data_dir.unwrap(), data_dir); + assert_eq!( + builder.server_options.get("max_connections").unwrap(), + "100" + ); + } } diff --git a/rust/pgrust/src/handshake/client_state_machine.rs b/rust/pgrust/src/handshake/client_state_machine.rs index 7b7b0dc3cf5..cfba84dda8a 100644 --- a/rust/pgrust/src/handshake/client_state_machine.rs +++ b/rust/pgrust/src/handshake/client_state_machine.rs @@ -93,8 +93,12 @@ pub trait ConnectionStateUpdate: ConnectionStateSend { fn parameter(&mut self, name: &str, value: &str) {} fn cancellation_key(&mut self, pid: i32, key: i32) {} fn state_changed(&mut self, state: ConnectionStateType) {} - fn server_error(&mut self, error: &PgServerError) {} - fn server_notice(&mut self, notice: &PgServerError) {} + fn server_error(&mut self, error: &PgServerError) { + error!("Server error during handshake: {:?}", error); + } + fn server_notice(&mut self, notice: &PgServerError) { + warn!("Server notice during handshake: {:?}", notice); + } fn auth(&mut self, auth: AuthType) {} } diff --git a/rust/pgrust/tests/real_postgres.rs b/rust/pgrust/tests/real_postgres.rs index 0d562e2047e..185bd1bc12a 100644 --- a/rust/pgrust/tests/real_postgres.rs +++ b/rust/pgrust/tests/real_postgres.rs @@ -17,8 +17,41 @@ fn address(address: &ListenAddress) -> ResolvedTarget { } } +/// Ensure that a notice doesn't cause unexpected behavior. +#[test_log::test(tokio::test)] +async fn test_auth_noisy() -> Result<(), Box> { + let Ok(builder) = PostgresBuilder::new().with_automatic_bin_path() else { + return Ok(()); + }; + + let builder = builder + .debug_level(5) + .server_option("client_min_messages", "debug5"); + + let process = builder.build()?; + + let credentials = Credentials { + username: DEFAULT_USERNAME.to_string(), + password: DEFAULT_PASSWORD.to_string(), + database: DEFAULT_DATABASE.to_string(), + server_settings: Default::default(), + }; + + let client = address(&process.socket_address).connect().await?; + + let ssl_requirement = ConnectionSslRequirement::Optional; + + let params = connect_raw_ssl(credentials, ssl_requirement, create_ssl_client()?, client) + .await? + .params() + .clone(); + assert_eq!(params.auth, AuthType::Trust); + + Ok(()) +} + #[rstest] -#[tokio::test] +#[test_log::test(tokio::test)] async fn test_auth_real( #[values(AuthType::Trust, AuthType::Plain, AuthType::Md5, AuthType::ScramSha256)] auth: AuthType,