From 410821d3f3cf008b22b901fc68c227e800611a48 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Thu, 2 Jan 2025 07:53:12 -0700 Subject: [PATCH 1/7] Enhanced query protocol in pgrust (to replace pgcon) (#8055) Support proper pipelining in the Client and the required message parsing to do so. This is not sufficient to replace `pgcon.py*` yet, but a big step towards that. The `--example connect` now supports an `--extended` flag allowing us to send extended messages instead of `Query`. This supports: - Non-extended `Query` messages - Extended query messages in a `Sync`-delimited pipeline (including mixing `Query` messages with those) - `COPY OUT` style statements in both `Query` and `Execute` (not `COPY IN` at this time, however) Unsupported: - `COPY IN` - Async notifications and `ParameterStatus` messages are handled as `warn!` and not exposed yet --------- Co-authored-by: Fantix King --- Cargo.lock | 33 +- Cargo.toml | 2 + rust/captive_postgres/Cargo.toml | 17 + rust/captive_postgres/README.md | 5 + rust/captive_postgres/src/lib.rs | 384 +++++++ rust/pgrust/Cargo.toml | 11 +- rust/pgrust/examples/connect.rs | 182 ++- rust/pgrust/src/connection/conn.rs | 946 ++++++++++++---- rust/pgrust/src/connection/flow.rs | 1234 +++++++++++++++++++++ rust/pgrust/src/connection/mod.rs | 21 +- rust/pgrust/src/connection/queue.rs | 166 +++ rust/pgrust/src/connection/raw_conn.rs | 56 +- rust/pgrust/src/connection/stream.rs | 29 + rust/pgrust/src/protocol/datatypes.rs | 21 +- rust/pgrust/src/protocol/gen.rs | 3 +- rust/pgrust/src/protocol/message_group.rs | 2 +- rust/pgrust/src/protocol/mod.rs | 11 + rust/pgrust/src/protocol/postgres.rs | 4 +- rust/pgrust/tests/query_real_postgres.rs | 354 ++++++ rust/pgrust/tests/real_postgres.rs | 380 +------ 20 files changed, 3242 insertions(+), 619 deletions(-) create mode 100644 rust/captive_postgres/Cargo.toml create mode 100644 rust/captive_postgres/README.md create mode 100644 rust/captive_postgres/src/lib.rs create mode 100644 
rust/pgrust/src/connection/flow.rs create mode 100644 rust/pgrust/src/connection/queue.rs create mode 100644 rust/pgrust/tests/query_real_postgres.rs diff --git a/Cargo.lock b/Cargo.lock index 7a3d36cb912..d550dc2d969 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,6 +242,20 @@ name = "bytemuck" version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "byteorder" @@ -255,6 +269,16 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +[[package]] +name = "captive_postgres" +version = "0.1.0" +dependencies = [ + "gel_auth", + "openssl", + "socket2", + "tempfile", +] + [[package]] name = "cardinality-estimator" version = "1.0.2" @@ -1560,11 +1584,14 @@ name = "pgrust" version = "0.1.0" dependencies = [ "base64", + "bytemuck", + "captive_postgres", "clap", "clap_derive", "derive_more", "futures", "gel_auth", + "hex-literal", "hexdump", "libc", "openssl", @@ -1577,8 +1604,6 @@ dependencies = [ "scopeguard", "serde", "serde_derive", - "socket2", - "tempfile", "test-log", "thiserror 1.0.63", "tokio", @@ -2292,9 +2317,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", diff --git a/Cargo.toml b/Cargo.toml index a7ec25a43fd..26ae534004e 
100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "edb/graphql-rewrite", "edb/server/_rust_native", "rust/auth", + "rust/captive_postgres", "rust/conn_pool", "rust/pgrust", "rust/http", @@ -20,6 +21,7 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter"] } gel_auth = { path = "rust/auth" } +captive_postgres = { path = "rust/captive_postgres" } conn_pool = { path = "rust/conn_pool" } pgrust = { path = "rust/pgrust" } http = { path = "rust/http" } diff --git a/rust/captive_postgres/Cargo.toml b/rust/captive_postgres/Cargo.toml new file mode 100644 index 00000000000..995a4b6d482 --- /dev/null +++ b/rust/captive_postgres/Cargo.toml @@ -0,0 +1,17 @@ + +[package] +name = "captive_postgres" +version = "0.1.0" +license = "MIT/Apache-2.0" +authors = ["MagicStack Inc. "] +edition = "2021" + +[lints] +workspace = true + +[dependencies] +gel_auth.workspace = true + +openssl = "0.10.55" +tempfile = "3" +socket2 = "0.5.8" diff --git a/rust/captive_postgres/README.md b/rust/captive_postgres/README.md new file mode 100644 index 00000000000..36c9d67cd65 --- /dev/null +++ b/rust/captive_postgres/README.md @@ -0,0 +1,5 @@ +# captive_postgres + +A simple, captive Postgres server that can be used to test client connections. Each instance +is a freshly initialized Postgres server with the specified credentials. 
+ diff --git a/rust/captive_postgres/src/lib.rs b/rust/captive_postgres/src/lib.rs new file mode 100644 index 00000000000..03be2d2ff53 --- /dev/null +++ b/rust/captive_postgres/src/lib.rs @@ -0,0 +1,384 @@ +// Constants +use gel_auth::AuthType; +use openssl::ssl::{Ssl, SslContext, SslMethod}; +use std::io::{BufRead, BufReader, Write}; +use std::net::{Ipv4Addr, SocketAddr, TcpListener}; +use std::os::unix::fs::PermissionsExt; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use std::sync::{Arc, RwLock}; +use std::thread; +use std::time::{Duration, Instant}; +use tempfile::TempDir; + +pub const STARTUP_TIMEOUT_DURATION: Duration = Duration::from_secs(30); +pub const PORT_RELEASE_TIMEOUT: Duration = Duration::from_secs(30); +pub const LINGER_DURATION: Duration = Duration::from_secs(1); +pub const HOT_LOOP_INTERVAL: Duration = Duration::from_millis(100); +pub const DEFAULT_USERNAME: &str = "username"; +pub const DEFAULT_PASSWORD: &str = "password"; +pub const DEFAULT_DATABASE: &str = "postgres"; + +#[derive(Debug, Clone)] +pub enum ListenAddress { + Tcp(SocketAddr), + #[cfg(unix)] + Unix(PathBuf), +} + +/// Represents an ephemeral port that can be allocated and released for immediate re-use by another process. +struct EphemeralPort { + port: u16, + listener: Option, +} + +impl EphemeralPort { + /// Allocates a new ephemeral port. + /// + /// Returns a Result containing the EphemeralPort if successful, + /// or an IO error if the allocation fails. 
+ fn allocate() -> std::io::Result { + let socket = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.set_linger(Some(LINGER_DURATION))?; + socket.bind(&std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, 0)).into())?; + socket.listen(1)?; + let listener = TcpListener::from(socket); + let port = listener.local_addr()?.port(); + Ok(EphemeralPort { + port, + listener: Some(listener), + }) + } + + /// Consumes the EphemeralPort and returns the allocated port number. + fn take(self) -> u16 { + // Drop the listener to free up the port + drop(self.listener); + + // Loop until the port is free + let start = Instant::now(); + + // If we can successfully connect to the port, it's not fully closed + while start.elapsed() < PORT_RELEASE_TIMEOUT { + let res = std::net::TcpStream::connect((Ipv4Addr::LOCALHOST, self.port)); + if res.is_err() { + // If connection fails, the port is released + break; + } + std::thread::sleep(HOT_LOOP_INTERVAL); + } + + self.port + } +} + +struct StdioReader { + output: Arc>, +} + +impl StdioReader { + fn spawn(reader: R, prefix: &'static str) -> Self { + let output = Arc::new(RwLock::new(String::new())); + let output_clone = Arc::clone(&output); + + thread::spawn(move || { + let mut buf_reader = std::io::BufReader::new(reader); + loop { + let mut line = String::new(); + match buf_reader.read_line(&mut line) { + Ok(0) => break, + Ok(_) => { + if let Ok(mut output) = output_clone.write() { + output.push_str(&line); + } + eprint!("[{}]: {}", prefix, line); + } + Err(e) => { + let error_line = format!("Error reading {}: {}\n", prefix, e); + if let Ok(mut output) = output_clone.write() { + output.push_str(&error_line); + } + eprintln!("{}", error_line); + } + } + } + }); + + StdioReader { output } + } + + fn contains(&self, s: &str) -> bool { + if let Ok(output) = self.output.read() { + output.contains(s) + } else { + false + } + } +} + +fn 
init_postgres(initdb: &Path, data_dir: &Path, auth: AuthType) -> std::io::Result<()> { + let mut pwfile = tempfile::NamedTempFile::new()?; + writeln!(pwfile, "{}", DEFAULT_PASSWORD)?; + let mut command = Command::new(initdb); + command + .arg("-D") + .arg(data_dir) + .arg("-A") + .arg(match auth { + AuthType::Deny => "reject", + AuthType::Trust => "trust", + AuthType::Plain => "password", + AuthType::Md5 => "md5", + AuthType::ScramSha256 => "scram-sha-256", + }) + .arg("--pwfile") + .arg(pwfile.path()) + .arg("-U") + .arg(DEFAULT_USERNAME); + + let output = command.output()?; + + let status = output.status; + let output_str = String::from_utf8_lossy(&output.stdout).to_string(); + let error_str = String::from_utf8_lossy(&output.stderr).to_string(); + + eprintln!("initdb stdout:\n{}", output_str); + eprintln!("initdb stderr:\n{}", error_str); + + if !status.success() { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "initdb command failed", + )); + } + + Ok(()) +} + +fn run_postgres( + postgres_bin: &Path, + data_dir: &Path, + socket_path: &Path, + ssl: Option<(PathBuf, PathBuf)>, + port: u16, +) -> std::io::Result { + let mut command = Command::new(postgres_bin); + command + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .arg("-D") + .arg(data_dir) + .arg("-k") + .arg(socket_path) + .arg("-h") + .arg(Ipv4Addr::LOCALHOST.to_string()) + .arg("-F") + // Useful for debugging + // .arg("-d") + // .arg("5") + .arg("-p") + .arg(port.to_string()); + + if let Some((cert_path, key_path)) = ssl { + let postgres_cert_path = data_dir.join("server.crt"); + let postgres_key_path = data_dir.join("server.key"); + std::fs::copy(cert_path, &postgres_cert_path)?; + std::fs::copy(key_path, &postgres_key_path)?; + // Set permissions for the certificate and key files + std::fs::set_permissions(&postgres_cert_path, std::fs::Permissions::from_mode(0o600))?; + std::fs::set_permissions(&postgres_key_path, std::fs::Permissions::from_mode(0o600))?; + + // Edit pg_hba.conf 
to change all "host" line prefixes to "hostssl" + let pg_hba_path = data_dir.join("pg_hba.conf"); + let content = std::fs::read_to_string(&pg_hba_path)?; + let modified_content = content + .lines() + .filter(|line| !line.starts_with("#") && !line.is_empty()) + .map(|line| { + if line.trim_start().starts_with("host") { + line.replacen("host", "hostssl", 1) + } else { + line.to_string() + } + }) + .collect::>() + .join("\n"); + eprintln!("pg_hba.conf:\n==========\n{modified_content}\n=========="); + std::fs::write(&pg_hba_path, modified_content)?; + + command.arg("-l"); + } + + let mut child = command.spawn()?; + + let stdout_reader = BufReader::new(child.stdout.take().expect("Failed to capture stdout")); + let _ = StdioReader::spawn(stdout_reader, "stdout"); + let stderr_reader = BufReader::new(child.stderr.take().expect("Failed to capture stderr")); + let stderr_reader = StdioReader::spawn(stderr_reader, "stderr"); + + let start_time = Instant::now(); + + let mut tcp_socket: Option = None; + let mut unix_socket: Option = None; + + let unix_socket_path = get_unix_socket_path(socket_path, port); + let tcp_socket_addr = std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, port)); + let mut db_ready = false; + + while start_time.elapsed() < STARTUP_TIMEOUT_DURATION { + std::thread::sleep(HOT_LOOP_INTERVAL); + match child.try_wait() { + Ok(Some(status)) => { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("PostgreSQL exited with status: {}", status), + )) + } + Err(e) => return Err(e), + _ => {} + } + if !db_ready && stderr_reader.contains("database system is ready to accept connections") { + eprintln!("Database is ready"); + db_ready = true; + } else { + continue; + } + if unix_socket.is_none() { + unix_socket = std::os::unix::net::UnixStream::connect(&unix_socket_path).ok(); + } + if tcp_socket.is_none() { + tcp_socket = std::net::TcpStream::connect(tcp_socket_addr).ok(); + } + if unix_socket.is_some() && tcp_socket.is_some() { + break; + } + } + 
+ if unix_socket.is_some() && tcp_socket.is_some() { + return Ok(child); + } + + // Print status for TCP/unix sockets + if let Some(tcp) = &tcp_socket { + eprintln!( + "TCP socket at {tcp_socket_addr:?} bound successfully on {}", + tcp.local_addr()? + ); + } else { + eprintln!("TCP socket at {tcp_socket_addr:?} binding failed"); + } + + if unix_socket.is_some() { + eprintln!("Unix socket at {unix_socket_path:?} connected successfully"); + } else { + eprintln!("Unix socket at {unix_socket_path:?} connection failed"); + } + + Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "PostgreSQL failed to start within 30 seconds", + )) +} + +fn test_data_dir() -> std::path::PathBuf { + let cargo_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../tests"); + if cargo_path.exists() { + cargo_path + } else { + Path::new("../../tests") + .canonicalize() + .expect("Failed to canonicalize tests directory path") + } +} + +fn postgres_bin_dir() -> std::io::Result { + let cargo_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../build/postgres/install/bin"); + if cargo_path.exists() { + cargo_path.canonicalize() + } else { + Path::new("../../build/postgres/install/bin").canonicalize() + } +} + +fn get_unix_socket_path(socket_path: &Path, port: u16) -> PathBuf { + socket_path.join(format!(".s.PGSQL.{}", port)) +} + +#[derive(Debug, Clone, Copy)] +pub enum Mode { + Tcp, + TcpSsl, + Unix, +} + +pub fn create_ssl_client() -> Result> { + let ssl_context = SslContext::builder(SslMethod::tls_client())?.build(); + let mut ssl = Ssl::new(&ssl_context)?; + ssl.set_connect_state(); + Ok(ssl) +} + +pub struct PostgresProcess { + child: std::process::Child, + pub socket_address: ListenAddress, + #[allow(unused)] + temp_dir: TempDir, +} + +impl Drop for PostgresProcess { + fn drop(&mut self) { + let _ = self.child.kill(); + } +} + +/// Creates and runs a new Postgres server process in a temporary directory. 
+pub fn setup_postgres( + auth: AuthType, + mode: Mode, +) -> Result, Box> { + let Ok(bindir) = postgres_bin_dir() else { + println!("Skipping test: postgres bin dir not found"); + return Ok(None); + }; + + let initdb = bindir.join("initdb"); + let postgres = bindir.join("postgres"); + + if !initdb.exists() || !postgres.exists() { + println!("Skipping test: initdb or postgres not found"); + return Ok(None); + } + + let temp_dir = TempDir::new()?; + let port = EphemeralPort::allocate()?; + let data_dir = temp_dir.path().join("data"); + + init_postgres(&initdb, &data_dir, auth)?; + let ssl_key = match mode { + Mode::TcpSsl => { + let certs_dir = test_data_dir().join("certs"); + let cert = certs_dir.join("server.cert.pem"); + let key = certs_dir.join("server.key.pem"); + Some((cert, key)) + } + _ => None, + }; + + let port = port.take(); + let child = run_postgres(&postgres, &data_dir, &data_dir, ssl_key, port)?; + + let socket_address = match mode { + Mode::Unix => ListenAddress::Unix(get_unix_socket_path(&data_dir, port)), + Mode::Tcp | Mode::TcpSsl => { + ListenAddress::Tcp(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port)) + } + }; + + Ok(Some(PostgresProcess { + child, + socket_address, + temp_dir, + })) +} diff --git a/rust/pgrust/Cargo.toml b/rust/pgrust/Cargo.toml index a0eaf7acf88..05c6961e53e 100644 --- a/rust/pgrust/Cargo.toml +++ b/rust/pgrust/Cargo.toml @@ -31,6 +31,7 @@ serde = "1" serde_derive = "1" percent-encoding = "2" uuid = "1" +bytemuck = { version = "1", features = ["derive"] } [dependencies.derive_more] version = "1.0.0-beta.6" @@ -38,19 +39,15 @@ features = ["full"] [dev-dependencies] tracing-subscriber.workspace = true -scopeguard = "1" +captive_postgres.workspace = true +scopeguard = "1" pretty_assertions = "1.2.0" test-log = { version = "0", features = ["trace"] } rstest = "0" clap = "4" clap_derive = "4" -tempfile = "3" -socket2 = "0.5.7" libc = "0.2.158" - -[dev-dependencies.tokio] -version = "1" -features = ["macros", "rt-multi-thread", 
"time", "test-util"] +hex-literal = "0.4.1" [lib] diff --git a/rust/pgrust/examples/connect.rs b/rust/pgrust/examples/connect.rs index bb26dafddcc..d23faaefbe4 100644 --- a/rust/pgrust/examples/connect.rs +++ b/rust/pgrust/examples/connect.rs @@ -1,9 +1,16 @@ +use captive_postgres::{ + setup_postgres, ListenAddress, Mode, DEFAULT_DATABASE, DEFAULT_PASSWORD, DEFAULT_USERNAME, +}; use clap::Parser; use clap_derive::Parser; +use gel_auth::AuthType; use openssl::ssl::{Ssl, SslContext, SslMethod}; use pgrust::{ - connection::{dsn::parse_postgres_dsn_env, Client, Credentials, ResolvedTarget}, - protocol::postgres::data::{DataRow, ErrorResponse, RowDescription}, + connection::{ + dsn::parse_postgres_dsn_env, Client, Credentials, ExecuteSink, Format, MaxRows, + PipelineBuilder, Portal, QuerySink, ResolvedTarget, Statement, + }, + protocol::postgres::data::{CopyData, CopyOutResponse, DataRow, ErrorResponse, RowDescription}, }; use std::net::SocketAddr; use tokio::task::LocalSet; @@ -11,6 +18,10 @@ use tokio::task::LocalSet; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { + /// Use an ephemeral database + #[clap(short = 'e', long = "ephemeral", conflicts_with_all = &["dsn", "unix", "tcp", "username", "password", "database"])] + ephemeral: bool, + #[clap(short = 'D', long = "dsn", value_parser, conflicts_with_all = &["unix", "tcp", "username", "password", "database"])] dsn: Option, @@ -44,6 +55,10 @@ struct Args { )] database: String, + /// Use extended query syntax + #[clap(short = 'x', long = "extended")] + extended: bool, + /// SQL statements to run #[clap( name = "statements", @@ -54,6 +69,16 @@ struct Args { statements: Option>, } +fn address(address: &ListenAddress) -> ResolvedTarget { + match address { + ListenAddress::Tcp(addr) => ResolvedTarget::SocketAddr(*addr), + #[cfg(unix)] + ListenAddress::Unix(path) => ResolvedTarget::UnixSocketAddr( + std::os::unix::net::SocketAddr::from_pathname(path).unwrap(), + ), + } +} + 
#[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); @@ -61,6 +86,22 @@ async fn main() -> Result<(), Box> { eprintln!("{args:?}"); let mut socket_address: Option = None; + + let _ephemeral = if args.ephemeral { + let process = setup_postgres(AuthType::Trust, Mode::Unix)?; + let Some(process) = process else { + eprintln!("Failed to start ephemeral database"); + return Err("Failed to start ephemeral database".into()); + }; + socket_address = Some(address(&process.socket_address)); + args.username = DEFAULT_USERNAME.to_string(); + args.password = DEFAULT_PASSWORD.to_string(); + args.database = DEFAULT_DATABASE.to_string(); + Some(process) + } else { + None + }; + if let Some(dsn) = args.dsn { let mut conn = parse_postgres_dsn_env(&dsn, std::env::vars())?; #[allow(deprecated)] @@ -97,16 +138,96 @@ async fn main() -> Result<(), Box> { .unwrap_or_else(|| vec!["select 1;".to_string()]); let local = LocalSet::new(); local - .run_until(run_queries(socket_address, credentials, statements)) + .run_until(run_queries( + socket_address, + credentials, + statements, + args.extended, + )) .await?; Ok(()) } +fn logging_sink() -> impl QuerySink { + ( + |rows: RowDescription<'_>| { + eprintln!("\nFields:"); + for field in rows.fields() { + eprint!(" {:?}", field.name()); + } + eprintln!(); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |row: DataRow<'_>| { + let _ = &guard; + eprintln!("Row:"); + for field in row.values() { + eprint!(" {:?}", field); + } + eprintln!(); + } + }, + |_: CopyOutResponse<'_>| { + eprintln!("\nCopy:"); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |data: CopyData<'_>| { + let _ = &guard; + eprintln!("Chunk:"); + for line in hexdump::hexdump_iter(data.data().as_ref()) { + eprintln!("{line}"); + } + } + }, + |error: ErrorResponse<'_>| { + eprintln!("\nError:\n {:?}", error); + }, + ) +} + +fn logging_sink_execute() -> impl ExecuteSink { + ( + || { + eprintln!(); + let 
guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |row: DataRow<'_>| { + let _ = &guard; + eprintln!("Row:"); + for field in row.values() { + eprint!(" {:?}", field); + } + eprintln!(); + } + }, + |_: CopyOutResponse<'_>| { + eprintln!("\nCopy:"); + let guard = scopeguard::guard((), |_| { + eprintln!("Done"); + }); + move |data: CopyData<'_>| { + let _ = &guard; + eprintln!("Chunk:"); + for line in hexdump::hexdump_iter(data.data().as_ref()) { + eprintln!("{line}"); + } + } + }, + |error: ErrorResponse<'_>| { + eprintln!("\nError:\n {:?}", error); + }, + ) +} + async fn run_queries( socket_address: ResolvedTarget, credentials: Credentials, statements: Vec, + extended: bool, ) -> Result<(), Box> { let client = socket_address.connect().await?; let ssl = SslContext::builder(SslMethod::tls_client())?.build(); @@ -116,37 +237,36 @@ async fn run_queries( tokio::task::spawn_local(task); conn.ready().await?; - let local = LocalSet::new(); eprintln!("Statements: {statements:?}"); + for statement in statements { - let sink = ( - |rows: RowDescription<'_>| { - eprintln!("\nFields:"); - for field in rows.fields() { - eprint!(" {:?}", field.name()); - } - eprintln!(); - let guard = scopeguard::guard((), |_| { - eprintln!("Done"); - }); - move |row: Result, ErrorResponse<'_>>| { - let _ = &guard; - if let Ok(row) = row { - eprintln!("Row:"); - for field in row.values() { - eprint!(" {:?}", field); - } - eprintln!(); - } - } - }, - |error: ErrorResponse<'_>| { - eprintln!("\nError:\n {:?}", error); - }, - ); - local.spawn_local(conn.query(&statement, sink)); + if extended { + let conn = conn.clone(); + tokio::task::spawn_local(async move { + let pipeline = PipelineBuilder::default() + .parse(Statement::default(), &statement, &[], ()) + .describe_statement(Statement::default(), ()) + .bind( + Portal::default(), + Statement::default(), + &[], + &[Format::text()], + (), + ) + .describe_portal(Portal::default(), ()) + .execute( + Portal::default(), + 
MaxRows::Unlimited, + logging_sink_execute(), + ) + .build(); + conn.pipeline_sync(pipeline).await + }) + .await??; + } else { + tokio::task::spawn_local(conn.query(&statement, logging_sink())).await??; + } } - local.await; Ok(()) } diff --git a/rust/pgrust/src/connection/conn.rs b/rust/pgrust/src/connection/conn.rs index d26f1f69530..29a209f366d 100644 --- a/rust/pgrust/src/connection/conn.rs +++ b/rust/pgrust/src/connection/conn.rs @@ -1,27 +1,29 @@ use super::{ connect_raw_ssl, + flow::{MessageHandler, MessageResult, Pipeline, QuerySink}, raw_conn::RawClient, stream::{Stream, StreamWithUpgrade}, Credentials, }; use crate::{ - connection::ConnectionError, + connection::{ + flow::{QueryMessageHandler, SyncMessageHandler}, + ConnectionError, + }, handshake::ConnectionSslRequirement, protocol::{ - match_message, postgres::{ builder, - data::{ - CommandComplete, DataRow, ErrorResponse, Message, ReadyForQuery, RowDescription, - }, + data::{Message, NotificationResponse, ParameterStatus}, meta, }, StructBuffer, }, }; -use futures::FutureExt; +use futures::{future::Either, FutureExt}; use std::{ cell::RefCell, + future::ready, pin::Pin, sync::Arc, task::{ready, Poll}, @@ -35,17 +37,47 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{error, trace, warn, Level}; #[derive(Debug, thiserror::Error)] -pub enum PGError { +pub enum PGConnError { #[error("Invalid state")] InvalidState, + #[error("Postgres error: {0}")] + PgError(#[from] crate::errors::PgServerError), #[error("Connection failed: {0}")] Connection(#[from] ConnectionError), #[error("I/O error: {0}")] Io(#[from] std::io::Error), + /// If an operation in a pipeline group fails, all operations up to + /// the next sync are skipped. + #[error("Operation skipped because of previous pipeline failure: {0}")] + Skipped(crate::errors::PgServerError), #[error("Connection was closed")] Closed, } +/// A client for a PostgreSQL connection. 
+/// +/// ``` +/// # use pgrust::connection::*; +/// # _ = async { +/// # let config = (); +/// # let credentials = Credentials::default(); +/// # let (client, server) = ::tokio::io::duplex(64); +/// # let socket = client; +/// let (client, task) = Client::new(credentials, socket, config); +/// ::tokio::task::spawn_local(task); +/// +/// // Run a basic query +/// client.query("SELECT 1", ()).await?; +/// +/// // Run a pipelined extended query +/// client.pipeline_sync(PipelineBuilder::default() +/// .parse(Statement("stmt1"), "SELECT 1", &[], ()) +/// .bind(Portal("portal1"), Statement("stmt1"), &[], &[Format::text()], ()) +/// .execute(Portal("portal1"), MaxRows::Unlimited, ()) +/// .build()).await?; +/// # Ok::<(), PGConnError>(()) +/// # } +/// ``` pub struct Client where (B, C): StreamWithUpgrade, @@ -53,6 +85,17 @@ where conn: Rc>, } +impl Clone for Client +where + (B, C): StreamWithUpgrade, +{ + fn clone(&self) -> Self { + Self { + conn: self.conn.clone(), + } + } +} + impl Client where (B, C): StreamWithUpgrade, @@ -63,7 +106,7 @@ where credentials: Credentials, socket: B, config: C, - ) -> (Self, impl Future>) { + ) -> (Self, impl Future>) { let conn = Rc::new(PGConn::new_connection(async move { let ssl_mode = ConnectionSslRequirement::Optional; let raw = connect_raw_ssl(credentials, ssl_mode, config, socket).await?; @@ -74,107 +117,61 @@ where } /// Create a new PostgreSQL client and a background task. - pub fn new_raw(stm: RawClient) -> (Self, impl Future>) { + pub fn new_raw(stm: RawClient) -> (Self, impl Future>) { let conn = Rc::new(PGConn::new_raw(stm)); let task = conn.clone().task(); (Self { conn }, task) } - pub async fn ready(&self) -> Result<(), PGError> { + pub async fn ready(&self) -> Result<(), PGConnError> { self.conn.ready().await } + /// Performs a bare `Query` operation. 
The sink handles the following messages: + /// + /// - `RowDescription` + /// - `DataRow` + /// - `CopyOutResponse` + /// - `CopyData` + /// - `CopyDone` + /// - `EmptyQueryResponse` + /// - `ErrorResponse` + /// + /// `CopyInResponse` is not currently supported and will result in a `CopyFail` being + /// sent to the server. + /// + /// Cancellation safety: if the future is dropped after the first time it is polled, the operation will + /// continue to callany callbacks and run to completion. If it has not been polled, the operation will + /// not be submitted. pub fn query( &self, query: &str, f: impl QuerySink + 'static, - ) -> impl Future> { - self.conn.clone().query(query.to_owned(), f) - } -} - -struct ErasedQuerySink(Q); - -impl QuerySink for ErasedQuerySink -where - Q: QuerySink, - S: DataSink + 'static, -{ - type Output = Box; - fn error(&self, error: ErrorResponse) { - self.0.error(error) - } - fn rows(&self, rows: RowDescription) -> Self::Output { - Box::new(self.0.rows(rows)) - } -} - -pub trait QuerySink { - type Output: DataSink; - fn rows(&self, rows: RowDescription) -> Self::Output; - fn error(&self, error: ErrorResponse); -} - -impl QuerySink for Box -where - Q: QuerySink + 'static, - S: DataSink + 'static, -{ - type Output = Box; - fn rows(&self, rows: RowDescription) -> Self::Output { - Box::new(self.as_ref().rows(rows)) - } - fn error(&self, error: ErrorResponse) { - self.as_ref().error(error) - } -} - -impl QuerySink for (F1, F2) -where - F1: for<'a> Fn(RowDescription) -> S, - F2: for<'a> Fn(ErrorResponse), - S: DataSink, -{ - type Output = S; - fn rows(&self, rows: RowDescription) -> S { - (self.0)(rows) - } - fn error(&self, error: ErrorResponse) { - (self.1)(error) - } -} - -pub trait DataSink { - fn row(&self, values: Result); -} - -impl DataSink for () { - fn row(&self, _: Result) {} -} - -impl DataSink for F -where - F: for<'a> Fn(Result, ErrorResponse<'a>>), -{ - fn row(&self, values: Result) { - (self)(values) + ) -> impl Future> { + 
match self.conn.clone().query(query, f) { + Ok(f) => Either::Left(f), + Err(e) => Either::Right(ready(Err(e))), + } } -} -impl DataSink for Box { - fn row(&self, values: Result) { - self.as_ref().row(values) + /// Performs a set of pipelined steps as a `Sync` group. + /// + /// Cancellation safety: if the future is dropped after the first time it is polled, the operation will + /// continue to callany callbacks and run to completion. If it has not been polled, the operation will + /// not be submitted. + pub fn pipeline_sync( + &self, + pipeline: Pipeline, + ) -> impl Future> { + match self.conn.clone().pipeline_sync(pipeline) { + Ok(f) => Either::Left(f), + Err(e) => Either::Right(ready(Err(e))), + } } } -struct QueryWaiter { - #[allow(unused)] - tx: tokio::sync::mpsc::UnboundedSender<()>, - f: Box>>, - data: RefCell>>, -} - #[derive(derive_more::Debug)] +#[allow(clippy::type_complexity)] enum ConnState where (B, C): StreamWithUpgrade, @@ -183,8 +180,14 @@ where #[allow(clippy::type_complexity)] Connecting(Pin, ConnectionError>>>>), #[debug("Ready(..)")] - Ready(RawClient, VecDeque), - Error(PGError), + Ready { + client: RawClient, + handlers: VecDeque<( + Box, + Option>, + )>, + }, + Error(PGConnError), Closed, } @@ -193,33 +196,39 @@ where (B, C): StreamWithUpgrade, { state: RefCell>, - write_lock: tokio::sync::Mutex<()>, + queue: RefCell>>, ready_lock: Arc>, } impl PGConn where (B, C): StreamWithUpgrade, + B: 'static, + C: 'static, { pub fn new_connection( future: impl Future, ConnectionError>> + 'static, ) -> Self { Self { state: ConnState::Connecting(future.boxed_local()).into(), - write_lock: Default::default(), + queue: Default::default(), ready_lock: Default::default(), } } pub fn new_raw(stm: RawClient) -> Self { Self { - state: ConnState::Ready(stm, Default::default()).into(), - write_lock: Default::default(), + state: ConnState::Ready { + client: stm, + handlers: Default::default(), + } + .into(), + queue: Default::default(), ready_lock: 
Default::default(), } } - fn check_error(&self) -> Result<(), PGError> { + fn check_error(&self) -> Result<(), PGConnError> { let state = &mut *self.state.borrow_mut(); match state { ConnState::Error(..) => { @@ -229,97 +238,146 @@ where error!("Connection failed: {e:?}"); Err(e) } - ConnState::Closed => Err(PGError::Closed), + ConnState::Closed => Err(PGConnError::Closed), _ => Ok(()), } } #[inline(always)] - async fn ready(&self) -> Result<(), PGError> { + async fn ready(&self) -> Result<(), PGConnError> { let _ = self.ready_lock.lock().await; self.check_error() } - fn with_stream(&self, f: F) -> Result + fn with_stream(&self, f: F) -> Result where F: FnOnce(Pin<&mut RawClient>) -> T, { match &mut *self.state.borrow_mut() { - ConnState::Ready(ref mut raw_client, _) => Ok(f(Pin::new(raw_client))), - _ => Err(PGError::InvalidState), + ConnState::Ready { ref mut client, .. } => Ok(f(Pin::new(client))), + _ => Err(PGConnError::InvalidState), } } - async fn write(&self, mut buf: &[u8]) -> Result<(), PGError> { - let _lock = self.write_lock.lock().await; + fn write( + self: Rc, + message_handlers: Vec>, + buf: Vec, + ) -> Result, PGConnError> { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.clone().queue.borrow_mut().submit(async move { + // If the future was dropped before the first poll, we don't submit the operation + if tx.is_closed() { + return Ok(()); + } - if buf.is_empty() { - return Ok(()); - } - if tracing::enabled!(Level::TRACE) { - trace!("Write:"); - for s in hexdump::hexdump_iter(buf) { - trace!("{}", s); + // Once we're polled the first time, we can add the handlers + match &mut *self.state.borrow_mut() { + ConnState::Ready { handlers, .. 
} => { + let mut handlers_iter = message_handlers.into_iter(); + let mut tx = Some(tx); + while let Some(handler) = handlers_iter.next() { + if handlers_iter.len() == 0 { + handlers.push_back((handler, tx.take())); + } else { + handlers.push_back((handler, None)); + } + } + } + x => { + warn!("Connection state was not ready: {x:?}"); + return Err(PGConnError::InvalidState); + } } - } - loop { - let n = poll_fn(|cx| { - self.with_stream(|stm| { - let n = match ready!(stm.poll_write(cx, buf)) { - Ok(n) => n, - Err(e) => return Poll::Ready(Err(PGError::Io(e))), - }; - Poll::Ready(Ok(n)) - })? - }) - .await?; - if n == buf.len() { - break; + + if tracing::enabled!(Level::TRACE) { + trace!("Write:"); + for s in hexdump::hexdump_iter(&buf) { + trace!("{}", s); + } } - buf = &buf[n..]; - } - Ok(()) + + let mut buf = &buf[..]; + + loop { + let n = poll_fn(|cx| { + self.with_stream(|stm| { + let n = match ready!(stm.poll_write(cx, buf)) { + Ok(n) => n, + Err(e) => return Poll::Ready(Err(PGConnError::Io(e))), + }; + Poll::Ready(Ok(n)) + })? + }) + .await?; + if n == buf.len() { + break; + } + buf = &buf[n..]; + } + + Ok(()) + }); + + Ok(rx) } - fn process_message(&self, message: Option) -> Result<(), PGError> { + fn process_message(&self, message: Option) -> Result<(), PGConnError> { let state = &mut *self.state.borrow_mut(); match state { - ConnState::Ready(_, queue) => { - let message = message.ok_or(PGError::InvalidState); - match_message!(Ok(message?), Backend { - (RowDescription as row) => { - if let Some(qw) = queue.back() { - let qs = qw.f.rows(row); - *qw.data.borrow_mut() = Some(qs); - } - }, - (DataRow as row) => { - if let Some(qw) = queue.back() { - if let Some(qs) = &*qw.data.borrow() { - qs.row(Ok(row)) + ConnState::Ready { handlers, .. 
} => { + let message = message.ok_or(PGConnError::InvalidState)?; + if NotificationResponse::try_new(&message).is_some() { + warn!("Notification: {:?}", message); + return Ok(()); + } + if ParameterStatus::try_new(&message).is_some() { + warn!("ParameterStatus: {:?}", message); + return Ok(()); + } + if let Some((handler, _tx)) = handlers.front_mut() { + match handler.handle(message) { + MessageResult::SkipUntilSync => { + let mut found_sync = false; + let name = handler.name(); + while let Some((handler, _)) = handlers.front() { + if handler.is_sync() { + found_sync = true; + break; + } + trace!("skipping {}", handler.name()); + handlers.pop_front(); + } + if !found_sync { + warn!("Unexpected state in {name}: No sync handler found"); } } - }, - (CommandComplete) => { - if let Some(qw) = queue.back() { - *qw.data.borrow_mut() = None; + MessageResult::Continue => {} + MessageResult::Done => { + handlers.pop_front(); } - }, - (ReadyForQuery) => { - queue.pop_front(); - }, - (ErrorResponse as err) => { - if let Some(qw) = queue.back() { - qw.f.error(err); + MessageResult::Unknown => { + // TODO: Should the be exposed to the API consumer? + warn!( + "Unknown message in {} ({:?})", + handler.name(), + message.mtype() as char + ); } - }, - unknown => { - eprintln!("Unknown message: {unknown:?}"); - } - }); + MessageResult::UnexpectedState { complaint } => { + // TODO: Should the be exposed to the API consumer? + warn!( + "Unexpected state in {} while handling message ({:?}): {complaint}", + handler.name(), + message.mtype() as char + ); + } + }; + }; } ConnState::Connecting(..) => { - return Err(PGError::InvalidState); + return Err(PGConnError::InvalidState); } ConnState::Error(..) 
| ConnState::Closed => self.check_error()?, } @@ -327,9 +385,8 @@ where Ok(()) } - pub fn task(self: Rc) -> impl Future> { + pub fn task(self: Rc) -> impl Future> { let ready_lock = self.ready_lock.clone().try_lock_owned().unwrap(); - async move { poll_fn(|cx| { let mut state = self.state.borrow_mut(); @@ -339,17 +396,20 @@ where let raw = match result { Ok(raw) => raw, Err(e) => { - let error = PGError::Connection(e); + let error = PGConnError::Connection(e); *state = ConnState::Error(error); - return Poll::Ready(Ok::<_, PGError>(())); + return Poll::Ready(Ok::<_, PGConnError>(())); } }; - *state = ConnState::Ready(raw, VecDeque::new()); - Poll::Ready(Ok::<_, PGError>(())) + *state = ConnState::Ready { + client: raw, + handlers: Default::default(), + }; + Poll::Ready(Ok::<_, PGConnError>(())) } Poll::Pending => Poll::Pending, }, - ConnState::Ready(..) => Poll::Ready(Ok(())), + ConnState::Ready { .. } => Poll::Ready(Ok(())), ConnState::Error(..) | ConnState::Closed => Poll::Ready(self.check_error()), } }) @@ -361,10 +421,15 @@ where loop { let mut read_buffer = [0; 1024]; let n = poll_fn(|cx| { + // Poll the queue before we poll the read stream. Note that we toss + // the result here. Either we'll make progress or there's nothing to + // do. + while self.queue.borrow_mut().poll_next_unpin(cx).is_ready() {} + self.with_stream(|stm| { let mut buf = ReadBuf::new(&mut read_buffer); let res = ready!(stm.poll_read(cx, &mut buf)); - Poll::Ready(res.map(|_| buf.filled().len())).map_err(PGError::Io) + Poll::Ready(res.map(|_| buf.filled().len())).map_err(PGConnError::Io) })? 
}) .await?; @@ -377,6 +442,14 @@ where } buffer.push_fallible(&read_buffer[..n], |message| { + if let Ok(message) = &message { + if tracing::enabled!(Level::TRACE) { + trace!("Message ({:?})", message.mtype() as char); + for s in hexdump::hexdump_iter(message.__buf) { + trace!("{}", s); + } + } + }; self.process_message(Some(message.map_err(ConnectionError::ParseError)?)) })?; @@ -388,35 +461,526 @@ where } } - pub async fn query( + pub fn query( self: Rc, - query: String, + query: &str, f: impl QuerySink + 'static, - ) -> Result<(), PGError> { + ) -> Result>, PGConnError> { trace!("Query task started: {query}"); - let mut rx = match &mut *self.state.borrow_mut() { - ConnState::Ready(_, queue) => { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let f = Box::new(ErasedQuerySink(f)) as _; - queue.push_back(QueryWaiter { - tx, - f, - data: None.into(), - }); - rx + let message = builder::Query { query }.to_vec(); + let rx = self.write( + vec![Box::new(QueryMessageHandler { + sink: f, + data: None, + copy: None, + })], + message, + )?; + Ok(async { + _ = rx.await; + Ok(()) + }) + } + + pub fn pipeline_sync( + self: Rc, + pipeline: Pipeline, + ) -> Result>, PGConnError> { + trace!("Pipeline task started"); + let Pipeline { + mut messages, + mut handlers, + } = pipeline; + handlers.push(Box::new(SyncMessageHandler)); + messages.extend_from_slice(&builder::Sync::default().to_vec()); + + let rx = self.write(handlers, messages)?; + Ok(async { + _ = rx.await; + Ok(()) + }) + } +} + +#[cfg(test)] +mod tests { + use hex_literal::hex; + use std::{fmt::Write, time::Duration}; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt, DuplexStream}, + task::LocalSet, + time::timeout, + }; + + use crate::connection::{ + flow::{CopyDataSink, DataSink, DoneHandling}, + raw_conn::ConnectionParams, + }; + use crate::protocol::postgres::data::*; + + use super::*; + + impl QuerySink for Rc> { + type Output = Self; + type CopyOutput = Self; + fn rows(&mut self, rows: RowDescription) 
-> Self { + write!(self.borrow_mut(), "[table=[").unwrap(); + for field in rows.fields() { + write!(self.borrow_mut(), "{},", field.name().to_string_lossy()).unwrap(); + } + write!(self.borrow_mut(), "]").unwrap(); + self.clone() + } + fn copy(&mut self, copy: CopyOutResponse) -> Self { + write!( + self.borrow_mut(), + "[copy={:?} {:?}", + copy.format(), + copy.format_codes() + ) + .unwrap(); + self.clone() + } + fn error(&mut self, error: ErrorResponse) { + for field in error.fields() { + if field.etype() as char == 'C' { + write!( + self.borrow_mut(), + "[error {}]", + field.value().to_string_lossy() + ) + .unwrap(); + return; + } + } + write!(self.borrow_mut(), "[error ??? {:?}]", error).unwrap(); + } + } + + impl DataSink for Rc> { + fn row(&mut self, row: DataRow) { + write!(self.borrow_mut(), "[").unwrap(); + for value in row.values() { + write!(self.borrow_mut(), "{},", value.to_string_lossy()).unwrap(); + } + write!(self.borrow_mut(), "]").unwrap(); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => { + write!( + self.borrow_mut(), + " done={}]", + complete.tag().to_string_lossy() + ) + .unwrap(); + } + Err(error) => { + for field in error.fields() { + if field.etype() as char == 'C' { + write!( + self.borrow_mut(), + "[error {}]]", + field.value().to_string_lossy() + ) + .unwrap(); + return DoneHandling::Handled; + } + } + write!(self.borrow_mut(), "[error ??? 
{:?}]]", error).unwrap(); + } } - x => { - warn!("Connection state was not ready: {x:?}"); - return Err(PGError::InvalidState); + DoneHandling::Handled + } + } + + impl CopyDataSink for Rc> { + fn data(&mut self, data: CopyData) { + write!( + self.borrow_mut(), + "[{}]", + String::from_utf8_lossy(data.data().as_ref()) + ) + .unwrap(); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => { + write!( + self.borrow_mut(), + " done={}]", + complete.tag().to_string_lossy() + ) + .unwrap(); + } + Err(error) => { + for field in error.fields() { + if field.etype() as char == 'C' { + write!( + self.borrow_mut(), + "[error {}]]", + field.value().to_string_lossy() + ) + .unwrap(); + return DoneHandling::Handled; + } + } + write!(self.borrow_mut(), "[error ??? {:?}]]", error).unwrap(); + } + } + DoneHandling::Handled + } + } + + async fn read_expect(stream: &mut S, expected: &[u8]) { + let mut buf = vec![0u8; expected.len()]; + stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, expected); + } + + /// Perform a test using captured binary protocol data from a real server. 
+ async fn run_expect( + query_task: impl FnOnce(Client, Rc>) -> F + 'static, + expect: &'static [(&[u8], &[u8], &str)], + ) { + let f = async move { + let (mut s1, s2) = tokio::io::duplex(1024 * 1024); + + let (client, task) = Client::new_raw(RawClient::new(s2, ConnectionParams::default())); + let task_handle = tokio::task::spawn_local(task); + + let handle = tokio::task::spawn_local(async move { + let log = Rc::new(RefCell::new(String::new())); + query_task(client, log.clone()).await; + Rc::try_unwrap(log).unwrap().into_inner() + }); + + let mut log_expect = String::new(); + for (read, write, expect) in expect { + // Query[text=""] + eprintln!("read {read:?}"); + read_expect(&mut s1, read).await; + eprintln!("write {write:?}"); + s1.write_all(write).await.unwrap(); + log_expect.push_str(expect); } + + let log = handle.await.unwrap(); + + assert_eq!(log, log_expect); + + // EOF to trigger the task to exit + drop(s1); + + task_handle.await.unwrap().unwrap(); }; - let message = builder::Query { query: &query }.to_vec(); - self.write(&message).await?; - rx.recv().await; - Ok(()) + let local = LocalSet::new(); + let task = local.spawn_local(f); + + timeout(Duration::from_secs(1), local).await.unwrap(); + + // Ensure we detect panics inside the task + task.await.unwrap(); } -} -#[cfg(test)] -mod tests {} + #[test_log::test(tokio::test)] + async fn query_select_1() { + run_expect( + |client, log| async move { + client.query("SELECT 1", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 0d53454c 45435420 3100"), + // T, D, C, Z + &hex!("54000000 2100013f 636f6c75 6d6e3f00 00000000 00000000 00170004 ffffffff 00004400 00000b00 01000000 01314300 00000d53 454c4543 54203100 5a000000 0549"), + "[table=[?column?,][1,] done=SELECT 1]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_select_1_limit_0() { + run_expect( + |client, log| async move { + client.query("SELECT 1 LIMIT 0", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 
1553454c 45435420 31204c49 4d495420 3000"), + // T, C, Z + &hex!("54000000 2100013f 636f6c75 6d6e3f00 00000000 00000000 00170004 ffffffff 00004300 00000d53 454c4543 54203000 5a000000 0549"), + "[table=[?column?,] done=SELECT 0]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_copy_1() { + run_expect( + |client, log| async move { + client.query("copy (select 1) to stdout;", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 1f636f70 79202873 656c6563 74203129 20746f20 7374646f 75743b00"), + // H, d, c, C, Z + &hex!("48000000 09000001 00006400 00000631 0a630000 00044300 00000b43 4f505920 31005a00 00000549"), + "[copy=0 [0][1\n] done=COPY 1]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_copy_1_limit_0() { + run_expect( + |client, log| async move { + client.query("copy (select 1 limit 0) to stdout;", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 27636f70 79202873 656c6563 74203120 6c696d69 74203029 20746f20 7374646f 75743b00"), + // H, c, C, Z + &hex!("48000000 09000001 00006300 00000443 0000000b 434f5059 2030005a 00000005 49"), + "[copy=0 [0] done=COPY 0]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_copy_with_error_rows() { + run_expect( + |client, log| async move { + client.query("copy (select case when id = 2 then id/(id-2) else id end from (select generate_series(1,2) as id)) to stdout;", log.clone()).await.unwrap(); + }, + &[( + &hex!(""" + 51000000 72636f70 79202873 656c6563 + 74206361 73652077 68656e20 6964203d + 20322074 68656e20 69642f28 69642d32 + 2920656c 73652069 6420656e 64206672 + 6f6d2028 73656c65 63742067 656e6572 + 6174655f 73657269 65732831 2c322920 + 61732069 64292920 746f2073 74646f75 + 743b00 + """), + // H, d, E, Z + &hex!(""" + 48000000 09000001 00006400 00000631 + 0a450000 00415345 52524f52 00564552 + 524f5200 43323230 3132004d 64697669 + 73696f6e 20627920 7a65726f 0046696e + 742e6300 4c383431 0052696e 74346469 + 7600005a 00000005 49 
+ """), + "[copy=0 [0][1\n][error 22012]]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_error() { + run_expect( + |client, log| async move { + client.query("do $$begin raise exception 'hi'; end$$;", log.clone()).await.unwrap(); + }, + &[( + &hex!("51000000 2c646f20 24246265 67696e20 72616973 65206578 63657074 696f6e20 27686927 3b20656e 6424243b 00"), + // E, Z + &hex!(""" + 45000000 75534552 524f5200 56455252 + 4f520043 50303030 31004d68 69005750 + 4c2f7067 53514c20 66756e63 74696f6e + 20696e6c 696e655f 636f6465 5f626c6f + 636b206c 696e6520 31206174 20524149 + 53450046 706c5f65 7865632e 63004c33 + 39313100 52657865 635f7374 6d745f72 + 61697365 00005a00 00000549 + """), + "[error P0001]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_empty_do() { + run_expect( + |client, log| async move { + client + .query("do $$begin end$$;", log.clone()) + .await + .unwrap(); + }, + &[( + &hex!("51000000 16646f20 24246265 67696e20 656e6424 243b00"), + // C, Z + &hex!(""" + 43000000 07444f00 5a000000 0549 + """), + "", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_error_with_rows() { + run_expect( + |client, log| async move { + client.query("select case when id = 2 then id/(id-2) else 1 end from (select 1 as id union all select 2 as id);", log.clone()).await.unwrap(); + }, + &[( + &hex!(""" + 51000000 6673656c 65637420 63617365 + 20776865 6e206964 203d2032 20746865 + 6e206964 2f286964 2d322920 656c7365 + 20312065 6e642066 726f6d20 2873656c + 65637420 31206173 20696420 756e696f + 6e20616c 6c207365 6c656374 20322061 + 73206964 293b00 + """), + // T, D, E, Z + &hex!(""" + 54000000 1d000163 61736500 00000000 + 00000000 00170004 ffffffff 00004400 + 00000b00 01000000 01314500 00004153 + 4552524f 52005645 52524f52 00433232 + 30313200 4d646976 6973696f 6e206279 + 207a6572 6f004669 6e742e63 004c3834 + 31005269 6e743464 69760000 5a000000 + 0549 + """), + "[table=[case,][1,][error 22012]]", + )], + ) + 
.await; + } + + #[test_log::test(tokio::test)] + async fn query_second_errors() { + run_expect( + |client, log| async move { + client + .query("select; select 1/0;", log.clone()) + .await + .unwrap(); + }, + &[( + &hex!("51000000 1873656c 6563743b 2073656c 65637420 312f303b 00"), + // T, D, C, E, Z + &hex!(""" + 54000000 06000044 00000006 00004300 + 00000d53 454c4543 54203100 45000000 + 41534552 524f5200 56455252 4f520043 + 32323031 32004d64 69766973 696f6e20 + 6279207a 65726f00 46696e74 2e63004c + 38343100 52696e74 34646976 00005a00 + 00000549 + """), + "[table=[][] done=SELECT 1][error 22012]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_notification() { + run_expect( + |client, log| async move { + client + .query("listen a; select pg_notify('a','b')", log.clone()) + .await + .unwrap(); + }, + &[( + &hex!( + " + 51000000 286c6973 74656e20 613b2073 + 656c6563 74207067 5f6e6f74 69667928 + 2761272c 27622729 00 + " + ), + // C, T, D, C, A, Z + &hex!( + " + 43000000 0b4c4953 54454e00 54000000 + 22000170 675f6e6f 74696679 00000000 + 00000000 0008e600 04ffffff ff000044 + 0000000a 00010000 00004300 00000d53 + 454c4543 54203100 41000000 0c002cba + 5f610062 005a0000 000549 + " + ), + "[table=[pg_notify,][,] done=SELECT 1]", + )], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_two_empty() { + run_expect( + |client, log| async move { + client.query("", log.clone()).await.unwrap(); + client.query("", log.clone()).await.unwrap(); + }, + &[ + ( + &hex!("51000000 0500"), + // I, Z + &hex!("49000000 045a0000 000549"), + "", + ), + ( + &hex!("51000000 0500"), + // I, Z + &hex!("49000000 045a0000 000549"), + "", + ), + ], + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn query_two_error() { + run_expect( + |client, log| async move { + client.query(".", log.clone()).await.unwrap(); + client.query(".", log.clone()).await.unwrap(); + }, + &[ + ( + &hex!("51000000 062e00"), + // E, Z + &hex!(""" + 45000000 59534552 
524f5200 56455252 + 4f520043 34323630 31004d73 796e7461 + 78206572 726f7220 6174206f 72206e65 + 61722022 2e220050 31004673 63616e2e + 6c004c31 32343400 52736361 6e6e6572 + 5f797965 72726f72 00005a00 00000549 + """), + "[error 42601]", + ), + ( + &hex!("51000000 062e00"), + // E, Z + &hex!(""" + 45000000 59534552 524f5200 56455252 + 4f520043 34323630 31004d73 796e7461 + 78206572 726f7220 6174206f 72206e65 + 61722022 2e220050 31004673 63616e2e + 6c004c31 32343400 52736361 6e6e6572 + 5f797965 72726f72 00005a00 00000549 + """), + "[error 42601]", + ), + ], + ) + .await; + } +} diff --git a/rust/pgrust/src/connection/flow.rs b/rust/pgrust/src/connection/flow.rs new file mode 100644 index 00000000000..15b3628b175 --- /dev/null +++ b/rust/pgrust/src/connection/flow.rs @@ -0,0 +1,1234 @@ +//! Postgres flow notes: +//! +//! https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-PIPELINING +//! +//! https://segmentfault.com/a/1190000017136059 +//! +//! Extended query messages Parse, Bind, Describe, Execute, Close put the server +//! into a "skip-til-sync" mode when erroring. All messages other than Terminate (including +//! those not part of the extended query protocol) are skipped until an explicit Sync message is received. +//! +//! Sync closes _implicit_ but not _explicit_ transactions. +//! +//! Both Query and Execute may return COPY responses rather than rows. In the case of Query, +//! RowDescription + DataRow is replaced by CopyOutResponse + CopyData + CopyDone. In the case +//! of Execute, describing the portal will return NoData, but Execute will return CopyOutResponse + +//! CopyData + CopyDone. 
+ +use std::{cell::RefCell, num::NonZeroU32, rc::Rc}; + +use crate::protocol::{ + match_message, + postgres::{ + builder, + data::{ + BindComplete, CloseComplete, CommandComplete, CopyData, CopyDone, CopyOutResponse, + DataRow, EmptyQueryResponse, ErrorResponse, Message, NoData, NoticeResponse, + ParameterDescription, ParseComplete, PortalSuspended, ReadyForQuery, RowDescription, + }, + }, + Encoded, +}; + +#[derive(Debug, Clone, Copy)] +pub enum Param<'a> { + Null, + Text(&'a str), + Binary(&'a [u8]), +} + +#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +#[repr(transparent)] +pub struct Oid(u32); + +impl Oid { + pub fn unspecified() -> Self { + Self(0) + } + + pub fn from(oid: NonZeroU32) -> Self { + Self(oid.get()) + } +} + +#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +#[repr(transparent)] +pub struct Format(i16); + +impl Format { + pub fn text() -> Self { + Self(0) + } + + pub fn binary() -> Self { + Self(1) + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(i32)] +pub enum MaxRows { + Unlimited, + Limited(NonZeroU32), +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Portal<'a>(pub &'a str); + +#[derive(Debug, Clone, Copy, Default)] +pub struct Statement<'a>(pub &'a str); + +pub trait Flow { + fn to_vec(&self) -> Vec; +} + +/// Performs a prepared statement parse operation. +/// +/// Handles: +/// - `ParseComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct ParseFlow<'a> { + pub name: Statement<'a>, + pub query: &'a str, + pub param_types: &'a [Oid], +} + +/// Performs a prepared statement bind operation. +/// +/// Handles: +/// - `BindComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct BindFlow<'a> { + pub portal: Portal<'a>, + pub statement: Statement<'a>, + pub params: &'a [Param<'a>], + pub result_format_codes: &'a [Format], +} + +/// Performs a prepared statement execute operation. 
+/// +/// Handles: +/// - `CommandComplete` +/// - `DataRow` +/// - `PortalSuspended` +/// - `CopyOutResponse` +/// - `CopyData` +/// - `CopyDone` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct ExecuteFlow<'a> { + pub portal: Portal<'a>, + pub max_rows: MaxRows, +} + +/// Performs a portal describe operation. +/// +/// Handles: +/// - `RowDescription` +/// - `NoData` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct DescribePortalFlow<'a> { + pub name: Portal<'a>, +} + +/// Performs a statement describe operation. +/// +/// Handles: +/// - `RowDescription` +/// - `NoData` +/// - `ParameterDescription` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct DescribeStatementFlow<'a> { + pub name: Statement<'a>, +} + +/// Performs a portal close operation. +/// +/// Handles: +/// - `CloseComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct ClosePortalFlow<'a> { + pub name: Portal<'a>, +} + +/// Performs a statement close operation. +/// +/// Handles: +/// - `CloseComplete` +/// - `ErrorResponse` +#[derive(Debug, Clone, Copy)] +struct CloseStatementFlow<'a> { + pub name: Statement<'a>, +} + +/// Performs a query operation. 
+/// +/// Handles: +/// - `EmptyQueryResponse`: If no queries were specified in the text +/// - `CommandComplete`: For each fully-completed query +/// - `RowDescription`: For each query that returns data +/// - `DataRow`: For each row returned by a query +/// - `CopyOutResponse`: For each query that returns copy data +/// - `CopyData`: For each chunk of copy data returned by a query +/// - `CopyDone`: For each query that returns copy data +/// - `ErrorResponse`: For the first failed query +#[derive(Debug, Clone, Copy)] +struct QueryFlow<'a> { + pub query: &'a str, +} + +impl<'a> Flow for ParseFlow<'a> { + fn to_vec(&self) -> Vec { + let param_types = bytemuck::cast_slice(self.param_types); + builder::Parse { + statement: self.name.0, + query: self.query, + param_types, + } + .to_vec() + } +} + +impl<'a> Flow for BindFlow<'a> { + fn to_vec(&self) -> Vec { + let mut format_codes = Vec::with_capacity(self.params.len()); + let mut values = Vec::with_capacity(self.params.len()); + + for param in self.params { + match param { + Param::Null => { + format_codes.push(0); + values.push(Encoded::Null); + } + Param::Text(value) => { + format_codes.push(0); + values.push(Encoded::Value(value.as_bytes())); + } + Param::Binary(value) => { + format_codes.push(1); + values.push(Encoded::Value(value)); + } + } + } + + let result_format_codes = bytemuck::cast_slice(self.result_format_codes); + + builder::Bind { + portal: self.portal.0, + statement: self.statement.0, + format_codes: &format_codes, + values: &values, + result_format_codes, + } + .to_vec() + } +} + +impl<'a> Flow for ExecuteFlow<'a> { + fn to_vec(&self) -> Vec { + let max_rows = match self.max_rows { + MaxRows::Unlimited => 0, + MaxRows::Limited(n) => n.get() as i32, + }; + builder::Execute { + portal: self.portal.0, + max_rows, + } + .to_vec() + } +} + +impl<'a> Flow for DescribePortalFlow<'a> { + fn to_vec(&self) -> Vec { + builder::Describe { + name: self.name.0, + dtype: b'P', + } + .to_vec() + } +} + +impl<'a> Flow 
for DescribeStatementFlow<'a> { + fn to_vec(&self) -> Vec { + builder::Describe { + name: self.name.0, + dtype: b'S', + } + .to_vec() + } +} + +impl<'a> Flow for ClosePortalFlow<'a> { + fn to_vec(&self) -> Vec { + builder::Close { + name: self.name.0, + ctype: b'P', + } + .to_vec() + } +} + +impl<'a> Flow for CloseStatementFlow<'a> { + fn to_vec(&self) -> Vec { + builder::Close { + name: self.name.0, + ctype: b'S', + } + .to_vec() + } +} + +impl<'a> Flow for QueryFlow<'a> { + fn to_vec(&self) -> Vec { + builder::Query { query: self.query }.to_vec() + } +} + +pub(crate) enum MessageResult { + Continue, + Done, + SkipUntilSync, + Unknown, + UnexpectedState { complaint: &'static str }, +} + +pub(crate) trait MessageHandler { + fn handle(&mut self, message: Message) -> MessageResult; + fn name(&self) -> &'static str; + fn is_sync(&self) -> bool { + false + } +} + +pub(crate) struct SyncMessageHandler; + +impl MessageHandler for SyncMessageHandler { + fn handle(&mut self, message: Message) -> MessageResult { + if ReadyForQuery::try_new(&message).is_some() { + return MessageResult::Done; + } + MessageResult::Unknown + } + fn name(&self) -> &'static str { + "Sync" + } + fn is_sync(&self) -> bool { + true + } +} + +impl MessageHandler for (&'static str, F) +where + F: for<'a> FnMut(Message<'a>) -> MessageResult, +{ + fn handle(&mut self, message: Message) -> MessageResult { + (self.1)(message) + } + fn name(&self) -> &'static str { + self.0 + } +} + +pub trait FlowWithSink { + fn visit_flow(&self, f: impl FnMut(&dyn Flow)); + fn make_handler(self) -> Box; +} + +pub trait SimpleFlowSink { + fn handle(&mut self, result: Result<(), ErrorResponse>); +} + +impl SimpleFlowSink for () { + fn handle(&mut self, _: Result<(), ErrorResponse>) {} +} + +impl FnMut(Result<(), ErrorResponse>)> SimpleFlowSink for F { + fn handle(&mut self, result: Result<(), ErrorResponse>) { + (self)(result) + } +} + +impl FlowWithSink for (ParseFlow<'_>, S) { + fn visit_flow(&self, mut f: impl 
FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("Parse", move |message: Message<'_>| { + if ParseComplete::try_new(&message).is_some() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Some(msg) = ErrorResponse::try_new(&message) { + self.1.handle(Err(msg)); + return MessageResult::SkipUntilSync; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (BindFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("Bind", move |message: Message<'_>| { + if BindComplete::try_new(&message).is_some() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Some(msg) = ErrorResponse::try_new(&message) { + self.1.handle(Err(msg)); + return MessageResult::SkipUntilSync; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (ClosePortalFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("ClosePortal", move |message: Message<'_>| { + if CloseComplete::try_new(&message).is_some() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Some(msg) = ErrorResponse::try_new(&message) { + self.1.handle(Err(msg)); + return MessageResult::SkipUntilSync; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (CloseStatementFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(mut self) -> Box { + Box::new(("CloseStatement", move |message: Message<'_>| { + if CloseComplete::try_new(&message).is_some() { + self.1.handle(Ok(())); + return MessageResult::Done; + } + if let Some(msg) = ErrorResponse::try_new(&message) { + self.1.handle(Err(msg)); + return MessageResult::Done; + } + MessageResult::Unknown + })) + } +} + +impl FlowWithSink for (ExecuteFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn 
make_handler(self) -> Box { + Box::new(ExecuteMessageHandler { + sink: self.1, + data: None, + copy: None, + }) + } +} + +impl FlowWithSink for (QueryFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(self) -> Box { + Box::new(QueryMessageHandler { + sink: self.1, + data: None, + copy: None, + }) + } +} + +impl FlowWithSink for (DescribePortalFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(self) -> Box { + Box::new(DescribeMessageHandler { sink: self.1 }) + } +} + +impl FlowWithSink for (DescribeStatementFlow<'_>, S) { + fn visit_flow(&self, mut f: impl FnMut(&dyn Flow)) { + f(&self.0); + } + fn make_handler(self) -> Box { + Box::new(DescribeMessageHandler { sink: self.1 }) + } +} + +pub trait DescribeSink { + fn params(&mut self, params: ParameterDescription); + fn rows(&mut self, rows: RowDescription); + fn error(&mut self, error: ErrorResponse); +} + +impl DescribeSink for () { + fn params(&mut self, _: ParameterDescription) {} + fn rows(&mut self, _: RowDescription) {} + fn error(&mut self, _: ErrorResponse) {} +} + +impl DescribeSink for F +where + F: for<'a> FnMut(RowDescription<'a>), +{ + fn rows(&mut self, rows: RowDescription) { + (self)(rows) + } + fn params(&mut self, _params: ParameterDescription) {} + fn error(&mut self, _error: ErrorResponse) {} +} + +impl DescribeSink for (F1, F2) +where + F1: for<'a> FnMut(ParameterDescription<'a>), + F2: for<'a> FnMut(RowDescription<'a>), +{ + fn params(&mut self, params: ParameterDescription) { + (self.0)(params) + } + fn rows(&mut self, rows: RowDescription) { + (self.1)(rows) + } + fn error(&mut self, _error: ErrorResponse) {} +} + +struct DescribeMessageHandler { + sink: S, +} + +impl MessageHandler for DescribeMessageHandler { + fn name(&self) -> &'static str { + "Describe" + } + fn handle(&mut self, message: Message) -> MessageResult { + match_message!(Ok(message), Backend { + (ParameterDescription 
as params) => { + self.sink.params(params); + return MessageResult::Continue; + }, + (RowDescription as rows) => { + self.sink.rows(rows); + return MessageResult::Done; + }, + (NoData) => { + return MessageResult::Done; + }, + (ErrorResponse as err) => { + self.sink.error(err); + return MessageResult::SkipUntilSync; + }, + _unknown => { + return MessageResult::Unknown; + } + }) + } +} + +pub trait ExecuteSink { + type Output: ExecuteDataSink; + type CopyOutput: CopyDataSink; + + fn rows(&mut self) -> Self::Output; + fn copy(&mut self, copy: CopyOutResponse) -> Self::CopyOutput; + fn complete(&mut self, _complete: ExecuteCompletion) {} + fn notice(&mut self, _: NoticeResponse) {} + fn error(&mut self, error: ErrorResponse); +} + +pub enum ExecuteCompletion<'a> { + PortalSuspended(PortalSuspended<'a>), + CommandComplete(CommandComplete<'a>), +} + +impl ExecuteSink for () { + type Output = (); + type CopyOutput = (); + fn rows(&mut self) {} + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, _: ErrorResponse) {} +} + +impl ExecuteSink for (F1, F2) +where + F1: for<'a> FnMut() -> S, + F2: for<'a> FnMut(ErrorResponse<'a>), + S: ExecuteDataSink, +{ + type Output = S; + type CopyOutput = (); + fn rows(&mut self) -> S { + (self.0)() + } + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, error: ErrorResponse) { + (self.1)(error) + } +} + +impl ExecuteSink for (F1, F2, F3) +where + F1: for<'a> FnMut() -> S, + F2: for<'a> FnMut(CopyOutResponse<'a>) -> T, + F3: for<'a> FnMut(ErrorResponse<'a>), + S: ExecuteDataSink, + T: CopyDataSink, +{ + type Output = S; + type CopyOutput = T; + fn rows(&mut self) -> S { + (self.0)() + } + fn copy(&mut self, copy: CopyOutResponse) -> T { + (self.1)(copy) + } + fn error(&mut self, error: ErrorResponse) { + (self.2)(error) + } +} + +pub trait ExecuteDataSink { + /// Sink a row of data. + fn row(&mut self, values: DataRow); + /// Handle the completion of a command. If unimplemented, will be redirected to the parent. 
+ #[must_use] + fn done(&mut self, _result: Result) -> DoneHandling { + DoneHandling::RedirectToParent + } +} + +impl ExecuteDataSink for () { + fn row(&mut self, _: DataRow) {} +} + +impl ExecuteDataSink for F +where + F: for<'a> Fn(DataRow<'a>), +{ + fn row(&mut self, values: DataRow) { + (self)(values) + } +} + +/// A sink capable of handling standard query and COPY (out direction) messages. +pub trait QuerySink { + type Output: DataSink; + type CopyOutput: CopyDataSink; + + fn rows(&mut self, rows: RowDescription) -> Self::Output; + fn copy(&mut self, copy: CopyOutResponse) -> Self::CopyOutput; + fn complete(&mut self, _complete: CommandComplete) {} + fn notice(&mut self, _: NoticeResponse) {} + fn error(&mut self, error: ErrorResponse); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DoneHandling { + Handled, + RedirectToParent, +} + +pub trait DataSink { + /// Sink a row of data. + fn row(&mut self, values: DataRow); + /// Handle the completion of a command. If unimplemented, will be redirected to the parent. + #[must_use] + fn done(&mut self, _result: Result) -> DoneHandling { + DoneHandling::RedirectToParent + } +} + +pub trait CopyDataSink { + /// Sink a chunk of COPY data. + fn data(&mut self, values: CopyData); + /// Handle the completion of a COPY operation. If unimplemented, will be redirected to the parent. 
+ #[must_use] + fn done(&mut self, _result: Result) -> DoneHandling { + DoneHandling::RedirectToParent + } +} + +impl QuerySink for Box +where + Q: QuerySink + 'static, +{ + type Output = Box; + type CopyOutput = Box; + fn rows(&mut self, rows: RowDescription) -> Self::Output { + Box::new(self.as_mut().rows(rows)) + } + fn copy(&mut self, copy: CopyOutResponse) -> Self::CopyOutput { + Box::new(self.as_mut().copy(copy)) + } + fn complete(&mut self, _complete: CommandComplete) { + self.as_mut().complete(_complete) + } + fn error(&mut self, error: ErrorResponse) { + self.as_mut().error(error) + } +} + +impl QuerySink for () { + type Output = (); + type CopyOutput = (); + fn rows(&mut self, _: RowDescription) {} + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, _: ErrorResponse) {} +} + +impl QuerySink for (F1, F2) +where + F1: for<'a> FnMut(RowDescription<'a>) -> S, + F2: for<'a> FnMut(ErrorResponse<'a>), + S: DataSink, +{ + type Output = S; + type CopyOutput = (); + fn rows(&mut self, rows: RowDescription) -> S { + (self.0)(rows) + } + fn copy(&mut self, _: CopyOutResponse) {} + fn error(&mut self, error: ErrorResponse) { + (self.1)(error) + } +} + +impl QuerySink for (F1, F2, F3) +where + F1: for<'a> FnMut(RowDescription<'a>) -> S, + F2: for<'a> FnMut(CopyOutResponse<'a>) -> T, + F3: for<'a> FnMut(ErrorResponse<'a>), + S: DataSink, + T: CopyDataSink, +{ + type Output = S; + type CopyOutput = T; + fn rows(&mut self, rows: RowDescription) -> S { + (self.0)(rows) + } + fn copy(&mut self, copy: CopyOutResponse) -> T { + (self.1)(copy) + } + fn error(&mut self, error: ErrorResponse) { + (self.2)(error) + } +} + +impl DataSink for () { + fn row(&mut self, _: DataRow) {} +} + +impl DataSink for F +where + F: for<'a> Fn(DataRow<'a>), +{ + fn row(&mut self, values: DataRow) { + (self)(values) + } +} + +impl DataSink for Box { + fn row(&mut self, values: DataRow) { + self.as_mut().row(values) + } + fn done(&mut self, result: Result) -> DoneHandling { + 
self.as_mut().done(result) + } +} + +impl CopyDataSink for () { + fn data(&mut self, _: CopyData) {} +} + +impl CopyDataSink for F +where + F: for<'a> FnMut(CopyData<'a>), +{ + fn data(&mut self, values: CopyData) { + (self)(values) + } +} + +impl CopyDataSink for Box { + fn data(&mut self, values: CopyData) { + self.as_mut().data(values) + } + fn done(&mut self, result: Result) -> DoneHandling { + self.as_mut().done(result) + } +} + +pub(crate) struct ExecuteMessageHandler { + pub sink: Q, + pub data: Option, + pub copy: Option, +} + +impl MessageHandler for ExecuteMessageHandler { + fn name(&self) -> &'static str { + "Execute" + } + fn handle(&mut self, message: Message) -> MessageResult { + match_message!(Ok(message), Backend { + (CopyOutResponse as copy) => { + let sink = std::mem::replace(&mut self.copy, Some(self.sink.copy(copy))); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "copy sink exists" }; + } + }, + (CopyData as data) => { + if let Some(sink) = &mut self.copy { + sink.data(data); + } else { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + (CopyDone) => { + if self.copy.is_none() { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + (DataRow as row) => { + if self.data.is_none() { + self.data = Some(self.sink.rows()); + } + let Some(sink) = &mut self.data else { + unreachable!() + }; + sink.row(row) + }, + (PortalSuspended as complete) => { + if let Some(mut sink) = std::mem::take(&mut self.data) { + if sink.done(Ok(ExecuteCompletion::PortalSuspended(complete))) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::PortalSuspended(complete)); + } + } else { + return MessageResult::UnexpectedState { complaint: "data sink does not exist" }; + } + return MessageResult::Done; + }, + (CommandComplete as complete) => { + if let Some(mut sink) = std::mem::take(&mut self.copy) { + // If COPY has started, route this to the 
COPY sink. + if sink.done(Ok(complete)) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::CommandComplete(complete)); + } + } else if let Some(mut sink) = std::mem::take(&mut self.data) { + // If data has started, route this to the data sink. + if sink.done(Ok(ExecuteCompletion::CommandComplete(complete))) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::CommandComplete(complete)); + } + } else { + // Otherwise, create a new data sink and route to there. + if self.sink.rows().done(Ok(ExecuteCompletion::CommandComplete(complete))) == DoneHandling::RedirectToParent { + self.sink.complete(ExecuteCompletion::CommandComplete(complete)); + } + } + return MessageResult::Done; + }, + (EmptyQueryResponse) => { + // TODO: This should be exposed to the sink + return MessageResult::Done; + }, + + (ErrorResponse as err) => { + if let Some(mut sink) = std::mem::take(&mut self.copy) { + // If COPY has started, route this to the COPY sink. + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else if let Some(mut sink) = std::mem::take(&mut self.data) { + // If data has started, route this to the data sink. + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else { + // Otherwise, create a new data sink and route to there. 
+ if self.sink.rows().done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } + + return MessageResult::SkipUntilSync; + }, + (NoticeResponse as notice) => { + self.sink.notice(notice); + }, + + _unknown => { + return MessageResult::Unknown; + } + }); + MessageResult::Continue + } +} + +pub(crate) struct QueryMessageHandler { + pub sink: Q, + pub data: Option, + pub copy: Option, +} + +impl MessageHandler for QueryMessageHandler { + fn name(&self) -> &'static str { + "Query" + } + fn handle(&mut self, message: Message) -> MessageResult { + match_message!(Ok(message), Backend { + (CopyOutResponse as copy) => { + let sink = std::mem::replace(&mut self.copy, Some(self.sink.copy(copy))); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "copy sink exists" }; + } + }, + (CopyData as data) => { + if let Some(sink) = &mut self.copy { + sink.data(data); + } else { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + (CopyDone) => { + if self.copy.is_none() { + return MessageResult::UnexpectedState { complaint: "copy sink does not exist" }; + } + }, + + (RowDescription as row) => { + let sink = std::mem::replace(&mut self.data, Some(self.sink.rows(row))); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "data sink exists" }; + } + }, + (DataRow as row) => { + if let Some(sink) = &mut self.data { + sink.row(row) + } else { + return MessageResult::UnexpectedState { complaint: "data sink does not exist" }; + } + }, + (CommandComplete as complete) => { + let sink = std::mem::take(&mut self.data); + if let Some(mut sink) = sink { + if sink.done(Ok(complete)) == DoneHandling::RedirectToParent { + self.sink.complete(complete); + } + } else { + let sink = std::mem::take(&mut self.copy); + if let Some(mut sink) = sink { + if sink.done(Ok(complete)) == DoneHandling::RedirectToParent { + self.sink.complete(complete); + } + } else { + self.sink.complete(complete); + } 
+ } + }, + + (EmptyQueryResponse) => { + // Equivalent to CommandComplete, but no data was provided + let sink = std::mem::take(&mut self.data); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "data sink exists" }; + } else { + let sink = std::mem::take(&mut self.copy); + if sink.is_some() { + return MessageResult::UnexpectedState { complaint: "copy sink exists" }; + } + } + }, + + (ErrorResponse as err) => { + // Depending on the state of the sink, we direct the error to + // the appropriate handler. + if let Some(mut sink) = std::mem::take(&mut self.data) { + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else if let Some(mut sink) = std::mem::take(&mut self.copy) { + if sink.done(Err(err)) == DoneHandling::RedirectToParent { + self.sink.error(err); + } + } else { + // Top level errors must complete this operation + self.sink.error(err); + } + }, + (NoticeResponse as notice) => { + self.sink.notice(notice); + }, + + (ReadyForQuery) => { + // All operations are complete at this point. + if std::mem::take(&mut self.data).is_some() || std::mem::take(&mut self.copy).is_some() { + return MessageResult::UnexpectedState { complaint: "sink exists" }; + } + return MessageResult::Done; + }, + + _unknown => { + return MessageResult::Unknown; + } + }); + MessageResult::Continue + } +} + +#[derive(Default)] +pub struct PipelineBuilder { + handlers: Vec>, + messages: Vec, +} + +impl PipelineBuilder { + fn push_flow_with_sink(mut self, flow: impl FlowWithSink) -> Self { + flow.visit_flow(|flow| self.messages.extend_from_slice(&flow.to_vec())); + self.handlers.push(flow.make_handler()); + self + } + + /// Add a bind flow to the pipeline. 
+ pub fn bind( + self, + portal: Portal, + statement: Statement, + params: &[Param], + result_format_codes: &[Format], + handler: impl SimpleFlowSink + 'static, + ) -> Self { + self.push_flow_with_sink(( + BindFlow { + portal, + statement, + params, + result_format_codes, + }, + handler, + )) + } + + /// Add a parse flow to the pipeline. + pub fn parse( + self, + name: Statement, + query: &str, + param_types: &[Oid], + handler: impl SimpleFlowSink + 'static, + ) -> Self { + self.push_flow_with_sink(( + ParseFlow { + name, + query, + param_types, + }, + handler, + )) + } + + /// Add an execute flow to the pipeline. + /// + /// Note that this may be a COPY statement. In that case, the description of the portal + /// will not show any data returned, and this will use the `CopySink` of the provided + /// sink. In addition, COPY operations do not respect the `max_rows` parameter. + pub fn execute( + self, + portal: Portal, + max_rows: MaxRows, + handler: impl ExecuteSink + 'static, + ) -> Self { + self.push_flow_with_sink((ExecuteFlow { portal, max_rows }, handler)) + } + + /// Add a close portal flow to the pipeline. + pub fn close_portal(self, name: Portal, handler: impl SimpleFlowSink + 'static) -> Self { + self.push_flow_with_sink((ClosePortalFlow { name }, handler)) + } + + /// Add a close statement flow to the pipeline. + pub fn close_statement(self, name: Statement, handler: impl SimpleFlowSink + 'static) -> Self { + self.push_flow_with_sink((CloseStatementFlow { name }, handler)) + } + + /// Add a describe portal flow to the pipeline. Note that this will describe + /// both parameters and rows. + pub fn describe_portal(self, name: Portal, handler: impl DescribeSink + 'static) -> Self { + self.push_flow_with_sink((DescribePortalFlow { name }, handler)) + } + + /// Add a describe statement flow to the pipeline. Note that this will describe + /// only the rows of the portal. 
+ pub fn describe_statement(self, name: Statement, handler: impl DescribeSink + 'static) -> Self { + self.push_flow_with_sink((DescribeStatementFlow { name }, handler)) + } + + /// Add a query flow to the pipeline. + /// + /// Note that if a query fails, the pipeline will continue executing until it + /// completes or a non-query pipeline element fails. If a previous non-query + /// element of this pipeline failed, the query will not be executed. + pub fn query(self, query: &str, handler: impl QuerySink + 'static) -> Self { + self.push_flow_with_sink((QueryFlow { query }, handler)) + } + + pub fn build(self) -> Pipeline { + Pipeline { + handlers: self.handlers, + messages: self.messages, + } + } +} + +pub struct Pipeline { + pub(crate) handlers: Vec>, + pub(crate) messages: Vec, +} + +#[derive(Default)] +/// Accumulate raw messages from a flow. Useful mainly for testing. +pub struct FlowAccumulator { + data: Vec, + messages: Vec, +} + +impl FlowAccumulator { + pub fn push(&mut self, message: impl AsRef<[u8]>) { + self.messages.push(self.data.len()); + self.data.extend_from_slice(message.as_ref()); + } + + pub fn with_messages(&self, mut f: impl FnMut(Message)) { + for &offset in &self.messages { + // First get the message header + let message = Message::new(&self.data[offset..]).unwrap(); + let len = message.mlen(); + // Then resize the message to the correct length + let message = Message::new(&self.data[offset..offset + len + 1]).unwrap(); + f(message); + } + } +} + +impl QuerySink for Rc> { + type Output = Self; + type CopyOutput = Self; + fn rows(&mut self, message: RowDescription) -> Self { + self.borrow_mut().push(message); + self.clone() + } + fn copy(&mut self, message: CopyOutResponse) -> Self { + self.borrow_mut().push(message); + self.clone() + } + fn error(&mut self, message: ErrorResponse) { + self.borrow_mut().push(message); + } + fn complete(&mut self, complete: CommandComplete) { + self.borrow_mut().push(complete); + } + fn notice(&mut self, 
message: NoticeResponse) { + self.borrow_mut().push(message); + } +} + +impl ExecuteSink for Rc> { + type Output = Self; + type CopyOutput = Self; + + fn rows(&mut self) -> Self { + self.clone() + } + fn copy(&mut self, message: CopyOutResponse) -> Self { + self.borrow_mut().push(message); + self.clone() + } + fn error(&mut self, message: ErrorResponse) { + self.borrow_mut().push(message); + } + fn complete(&mut self, complete: ExecuteCompletion) { + match complete { + ExecuteCompletion::PortalSuspended(suspended) => self.borrow_mut().push(suspended), + ExecuteCompletion::CommandComplete(complete) => self.borrow_mut().push(complete), + } + } + fn notice(&mut self, message: NoticeResponse) { + self.borrow_mut().push(message); + } +} + +impl DataSink for Rc> { + fn row(&mut self, message: DataRow) { + self.borrow_mut().push(message); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => self.borrow_mut().push(complete), + Err(err) => self.borrow_mut().push(err), + }; + DoneHandling::Handled + } +} + +impl ExecuteDataSink for Rc> { + fn row(&mut self, message: DataRow) { + self.borrow_mut().push(message); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(ExecuteCompletion::PortalSuspended(suspended)) => self.borrow_mut().push(suspended), + Ok(ExecuteCompletion::CommandComplete(complete)) => self.borrow_mut().push(complete), + Err(err) => self.borrow_mut().push(err), + }; + DoneHandling::Handled + } +} + +impl CopyDataSink for Rc> { + fn data(&mut self, message: CopyData) { + self.borrow_mut().push(message); + } + fn done(&mut self, result: Result) -> DoneHandling { + match result { + Ok(complete) => self.borrow_mut().push(complete), + Err(err) => self.borrow_mut().push(err), + }; + DoneHandling::Handled + } +} + +impl SimpleFlowSink for Rc> { + fn handle(&mut self, result: Result<(), ErrorResponse>) { + match result { + Ok(()) => (), + Err(err) => self.borrow_mut().push(err), + } + } +} + +impl 
DescribeSink for Rc> { + fn params(&mut self, params: ParameterDescription) { + self.borrow_mut().push(params); + } + fn rows(&mut self, rows: RowDescription) { + self.borrow_mut().push(rows); + } + fn error(&mut self, error: ErrorResponse) { + self.borrow_mut().push(error); + } +} diff --git a/rust/pgrust/src/connection/mod.rs b/rust/pgrust/src/connection/mod.rs index e15be003092..c0a7e929759 100644 --- a/rust/pgrust/src/connection/mod.rs +++ b/rust/pgrust/src/connection/mod.rs @@ -7,13 +7,19 @@ use crate::{ mod conn; pub mod dsn; +mod flow; pub mod openssl; +pub(crate) mod queue; mod raw_conn; mod stream; pub mod tokio; -pub use conn::Client; +pub use conn::{Client, PGConnError}; use dsn::HostType; +pub use flow::{ + CopyDataSink, DataSink, DoneHandling, ExecuteSink, FlowAccumulator, Format, MaxRows, Oid, + Param, Pipeline, PipelineBuilder, Portal, QuerySink, Statement, +}; pub use raw_conn::connect_raw_ssl; macro_rules! __invalid_state { @@ -91,7 +97,7 @@ pub struct Credentials { pub server_settings: HashMap, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, derive_more::From)] /// The resolved target of a connection attempt. pub enum ResolvedTarget { SocketAddr(std::net::SocketAddr), @@ -100,6 +106,17 @@ pub enum ResolvedTarget { } impl ResolvedTarget { + #[cfg(test)] + pub fn from_captive_server_listen_address(address: captive_postgres::ListenAddress) -> Self { + match address { + captive_postgres::ListenAddress::Tcp(addr) => Self::SocketAddr(addr), + #[cfg(unix)] + captive_postgres::ListenAddress::Unix(path) => { + Self::UnixSocketAddr(std::os::unix::net::SocketAddr::from_pathname(path).unwrap()) + } + } + } + /// Resolves the target addresses for a given host. 
pub fn to_addrs_sync(host: &dsn::Host) -> Result, std::io::Error> { use std::net::{SocketAddr, ToSocketAddrs}; diff --git a/rust/pgrust/src/connection/queue.rs b/rust/pgrust/src/connection/queue.rs new file mode 100644 index 00000000000..211beed5809 --- /dev/null +++ b/rust/pgrust/src/connection/queue.rs @@ -0,0 +1,166 @@ +use std::future::Future; +use std::ops::DerefMut; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A queue of futures that can be polled in order. +/// +/// Only one future will be active at a time. If no futures are active, the +/// waker will be triggered when the next future is submitted to the queue. +pub struct FutureQueue { + queue: tokio::sync::mpsc::UnboundedReceiver>>>, + sender: tokio::sync::mpsc::UnboundedSender>>>, + current: Option>>>, +} + +#[cfg(test)] +#[derive(Clone)] +pub struct FutureQueueSender { + sender: tokio::sync::mpsc::UnboundedSender>>>, +} + +#[cfg(test)] +impl FutureQueueSender { + pub fn submit(&self, future: impl Future + 'static) { + // This will never fail because the receiver still exists + self.sender.send(Box::pin(future)).unwrap(); + } +} + +impl FutureQueue { + #[cfg(test)] + pub fn sender(&self) -> FutureQueueSender { + FutureQueueSender { + sender: self.sender.clone(), + } + } + + pub fn submit(&self, future: impl Future + 'static) { + // This will never fail because we hold both ends of the channel. + self.sender.send(Box::pin(future)).unwrap(); + } + + /// Poll the current future, or no current future, poll for the next item + /// from the queue (and then poll that future). + pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + if let Some(future) = self.current.as_mut() { + match future.as_mut().poll(cx) { + Poll::Ready(output) => { + self.current = None; + return Poll::Ready(Some(output)); + } + Poll::Pending => return Poll::Pending, + } + } + + // If there is no current future, try to receive the next one from the queue. 
+ let next = match self.queue.poll_recv(cx) { + Poll::Ready(Some(next)) => next, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + }; + + // Note that we loop around to poll this future until we get a Pending + // result. + self.current = Some(next); + } + } +} + +impl Default for FutureQueue { + fn default() -> Self { + let (sender, receiver) = tokio::sync::mpsc::unbounded_channel(); + Self { + queue: receiver, + sender, + current: None, + } + } +} + +impl futures::Stream for FutureQueue { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // We're Unpin + let this = self.deref_mut(); + this.poll_next_unpin(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::StreamExt; + use tokio::{ + task::LocalSet, + time::{sleep, Duration}, + }; + + #[tokio::test] + async fn test_basic_queue() { + LocalSet::new() + .run_until(async { + let mut queue = FutureQueue::default(); + let sender = queue.sender(); + + // Spawn a task that sends some futures + tokio::task::spawn_local(async move { + sleep(Duration::from_millis(10)).await; + sender.submit(async { 1 }); + sleep(Duration::from_millis(10)).await; + sender.submit(async { 2 }); + sleep(Duration::from_millis(10)).await; + sender.submit(async { 3 }); + }); + + // Collect results + let mut results = Vec::new(); + while let Some(value) = queue.next().await { + results.push(value); + if results.len() == 3 { + break; + } + } + + assert_eq!(results, vec![1, 2, 3]); + }) + .await; + } + + #[tokio::test] + async fn test_delayed_futures() { + LocalSet::new() + .run_until(async { + let mut queue = FutureQueue::default(); + let sender = queue.sender(); + + // Spawn task with delayed futures + tokio::task::spawn_local(async move { + sleep(Duration::from_millis(10)).await; + sender.submit(async { + sleep(Duration::from_millis(50)).await; + 1 + }); + sleep(Duration::from_millis(10)).await; + sender.submit(async { + 
sleep(Duration::from_millis(10)).await; + 2 + }); + }); + + // Even though second future completes first, results should be in order of sending + let mut results = Vec::new(); + while let Some(value) = queue.next().await { + results.push(value); + if results.len() == 2 { + break; + } + } + + assert_eq!(results, vec![1, 2]); + }) + .await; + } +} diff --git a/rust/pgrust/src/connection/raw_conn.rs b/rust/pgrust/src/connection/raw_conn.rs index 24d16885e23..dbd0930976e 100644 --- a/rust/pgrust/src/connection/raw_conn.rs +++ b/rust/pgrust/src/connection/raw_conn.rs @@ -1,5 +1,5 @@ use super::{ - stream::{Stream, StreamWithUpgrade, UpgradableStream}, + stream::{Stream, StreamWithUpgrade, UpgradableStream, UpgradableStreamChoice}, ConnectionError, Credentials, }; use crate::handshake::{ @@ -147,10 +147,20 @@ pub struct RawClient where (B, C): StreamWithUpgrade, { - stream: UpgradableStream, + stream: UpgradableStreamChoice, params: ConnectionParams, } +impl RawClient { + /// Create a new raw client from a stream. The stream must be fully authenticated and ready. 
+ pub fn new(stream: B, params: ConnectionParams) -> Self { + Self { + stream: UpgradableStreamChoice::Base(stream), + params, + } + } +} + impl RawClient where (B, C): StreamWithUpgrade, @@ -169,7 +179,10 @@ where cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_read(cx, buf) + match &mut self.get_mut().stream { + UpgradableStreamChoice::Base(base) => Pin::new(base).poll_read(cx, buf), + UpgradableStreamChoice::Upgrade(upgraded) => Pin::new(upgraded).poll_read(cx, buf), + } } } @@ -182,18 +195,47 @@ where cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_write(cx, buf) + match &mut self.get_mut().stream { + UpgradableStreamChoice::Base(base) => Pin::new(base).poll_write(cx, buf), + UpgradableStreamChoice::Upgrade(upgraded) => Pin::new(upgraded).poll_write(cx, buf), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + match &mut self.get_mut().stream { + UpgradableStreamChoice::Base(base) => Pin::new(base).poll_write_vectored(cx, bufs), + UpgradableStreamChoice::Upgrade(upgraded) => { + Pin::new(upgraded).poll_write_vectored(cx, bufs) + } + } + } + + fn is_write_vectored(&self) -> bool { + match &self.stream { + UpgradableStreamChoice::Base(base) => base.is_write_vectored(), + UpgradableStreamChoice::Upgrade(upgraded) => upgraded.is_write_vectored(), + } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_flush(cx) + match &mut self.get_mut().stream { + UpgradableStreamChoice::Base(base) => Pin::new(base).poll_flush(cx), + UpgradableStreamChoice::Upgrade(upgraded) => Pin::new(upgraded).poll_flush(cx), + } } fn poll_shutdown( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_shutdown(cx) + match &mut self.get_mut().stream { + UpgradableStreamChoice::Base(base) => 
Pin::new(base).poll_shutdown(cx), + UpgradableStreamChoice::Upgrade(upgraded) => Pin::new(upgraded).poll_shutdown(cx), + } } } @@ -245,6 +287,8 @@ where .drive_bytes(&mut state, &buffer[..n], &mut struct_buffer, &mut stream) .await?; } + + let stream = stream.into_choice().unwrap(); Ok(RawClient { stream, params: update.params, diff --git a/rust/pgrust/src/connection/stream.rs b/rust/pgrust/src/connection/stream.rs index 6606a083c93..07cdc91273c 100644 --- a/rust/pgrust/src/connection/stream.rs +++ b/rust/pgrust/src/connection/stream.rs @@ -33,6 +33,7 @@ impl StreamWithUpgrade for (S, ()) { } } +#[derive(derive_more::Debug)] pub struct UpgradableStream where (B, C): StreamWithUpgrade, @@ -79,6 +80,19 @@ where )), } } + + /// Convert the inner stream into a choice between the base and the upgraded stream. + /// + /// If the inner stream is in the process of upgrading, return an error containing `self`. + pub fn into_choice(self) -> Result, Self> { + match self.inner { + UpgradableStreamInner::Base(base, _) => Ok(UpgradableStreamChoice::Base(base)), + UpgradableStreamInner::Upgraded(upgraded) => { + Ok(UpgradableStreamChoice::Upgrade(upgraded)) + } + UpgradableStreamInner::Upgrading => Err(self), + } + } } impl tokio::io::AsyncRead for UpgradableStream @@ -185,11 +199,26 @@ where } } +#[derive(derive_more::Debug)] enum UpgradableStreamInner where (B, C): StreamWithUpgrade, { + #[debug("Base(..)")] Base(B, C), + #[debug("Upgraded(..)")] Upgraded(<(B, C) as StreamWithUpgrade>::Upgrade), + #[debug("Upgrading")] Upgrading, } + +#[derive(derive_more::Debug)] +pub enum UpgradableStreamChoice +where + (B, C): StreamWithUpgrade, +{ + #[debug("Base(..)")] + Base(B), + #[debug("Upgrade(..)")] + Upgrade(<(B, C) as StreamWithUpgrade>::Upgrade), +} diff --git a/rust/pgrust/src/protocol/datatypes.rs b/rust/pgrust/src/protocol/datatypes.rs index 99433de3ce4..b06b98575ea 100644 --- a/rust/pgrust/src/protocol/datatypes.rs +++ b/rust/pgrust/src/protocol/datatypes.rs @@ -385,6 +385,15 
@@ pub enum Encoded<'a> { Value(&'a [u8]), } +impl<'a> Encoded<'a> { + pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + match self { + Encoded::Null => "".into(), + Encoded::Value(value) => String::from_utf8_lossy(value), + } + } +} + impl<'a> AsRef> for Encoded<'a> { fn as_ref(&self) -> &Encoded<'a> { self @@ -443,10 +452,10 @@ impl FieldAccess { const N: usize = std::mem::size_of::(); if let Some(len) = buf.first_chunk::() { let len = i32::from_be_bytes(*len); - if len < 0 { - Err(ParseError::InvalidData) - } else if len == -1 { + if len == -1 { Ok(N) + } else if len < 0 { + Err(ParseError::InvalidData) } else if buf.len() < len as usize + N { Err(ParseError::TooShort) } else { @@ -461,10 +470,10 @@ impl FieldAccess { const N: usize = std::mem::size_of::(); if let Some((len, array)) = buf.split_first_chunk::() { let len = i32::from_be_bytes(*len); - if len < 0 { - Err(ParseError::InvalidData) - } else if len == -1 { + if len == -1 && array.is_empty() { Ok(Encoded::Null) + } else if len < 0 { + Err(ParseError::InvalidData) } else if array.len() < len as _ { Err(ParseError::TooShort) } else { diff --git a/rust/pgrust/src/protocol/gen.rs b/rust/pgrust/src/protocol/gen.rs index 19dca193326..e0fb066ea65 100644 --- a/rust/pgrust/src/protocol/gen.rs +++ b/rust/pgrust/src/protocol/gen.rs @@ -276,6 +276,7 @@ macro_rules! protocol_builder { $( " (value = `", stringify!($value), "`)", )? "\n\n" )* )] + #[derive(Copy, Clone)] pub struct $name<'a> { /// Our zero-copy buffer. #[doc(hidden)] @@ -371,7 +372,7 @@ macro_rules! protocol_builder { }) } - pub fn to_vec(&self) -> Vec { + pub fn to_vec(self) -> Vec { self.__buf.to_vec() } diff --git a/rust/pgrust/src/protocol/message_group.rs b/rust/pgrust/src/protocol/message_group.rs index 0f5f0720857..7201864dc87 100644 --- a/rust/pgrust/src/protocol/message_group.rs +++ b/rust/pgrust/src/protocol/message_group.rs @@ -21,7 +21,7 @@ macro_rules! 
message_group { #[allow(unused)] impl [<$group Builder>]<'_> { - pub fn to_vec(&self) -> Vec { + pub fn to_vec(self) -> Vec { match self { $( Self::$message(message) => message.to_vec(), diff --git a/rust/pgrust/src/protocol/mod.rs b/rust/pgrust/src/protocol/mod.rs index aead9728e4e..568272eabc6 100644 --- a/rust/pgrust/src/protocol/mod.rs +++ b/rust/pgrust/src/protocol/mod.rs @@ -610,6 +610,17 @@ mod tests { fuzz_test::(message); } + #[test] + fn test_datarow() { + let buf = [ + 0x44, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, + ]; + assert!(DataRow::is_buffer(&buf)); + let message = DataRow::new(&buf).unwrap(); + assert_eq!(message.values().len(), 1); + assert_eq!(message.values().into_iter().next().unwrap(), Encoded::Null); + } + #[test] fn test_edgedb_sasl() { use crate::protocol::edgedb::*; diff --git a/rust/pgrust/src/protocol/postgres.rs b/rust/pgrust/src/protocol/postgres.rs index 04bbdc106f5..df313ef106a 100644 --- a/rust/pgrust/src/protocol/postgres.rs +++ b/rust/pgrust/src/protocol/postgres.rs @@ -282,7 +282,7 @@ struct Close: Message { mtype: u8 = 'C', /// Length of message contents in bytes, including self. mlen: len, - /// 'xS' to close a prepared statement; 'P' to close a portal. + /// 'S' to close a prepared statement; 'P' to close a portal. ctype: u8, /// The name of the prepared statement or portal to close. name: ZTString, @@ -564,7 +564,7 @@ struct Parse: Message { mlen: len, /// The name of the destination prepared statement. statement: ZTString, - /// The query String to be parsed. + /// The query string to be parsed. query: ZTString, /// OIDs of the parameter data types. 
param_types: Array, diff --git a/rust/pgrust/tests/query_real_postgres.rs b/rust/pgrust/tests/query_real_postgres.rs new file mode 100644 index 00000000000..f8dde14716d --- /dev/null +++ b/rust/pgrust/tests/query_real_postgres.rs @@ -0,0 +1,354 @@ +use std::cell::RefCell; +use std::future::Future; +use std::num::NonZero; +use std::rc::Rc; + +// Constants +use gel_auth::AuthType; +use pgrust::connection::tokio::TokioStream; +use pgrust::connection::{ + Client, Credentials, FlowAccumulator, MaxRows, Oid, Param, PipelineBuilder, Portal, + ResolvedTarget, Statement, +}; +use pgrust::protocol::match_message; +use pgrust::protocol::postgres::data::*; +use tokio::task::LocalSet; + +use captive_postgres::*; + +fn address(address: &ListenAddress) -> ResolvedTarget { + match address { + ListenAddress::Tcp(addr) => ResolvedTarget::SocketAddr(*addr), + #[cfg(unix)] + ListenAddress::Unix(path) => ResolvedTarget::UnixSocketAddr( + std::os::unix::net::SocketAddr::from_pathname(path).unwrap(), + ), + } +} + +async fn with_postgres(callback: F) -> Result, Box> +where + F: FnOnce(Client, Rc>) -> R, + R: Future>>, +{ + let Some(postgres_process) = setup_postgres(AuthType::Trust, Mode::Tcp)? 
else { + return Ok(None); + }; + + let credentials = Credentials { + username: DEFAULT_USERNAME.to_string(), + password: DEFAULT_PASSWORD.to_string(), + database: DEFAULT_DATABASE.to_string(), + server_settings: Default::default(), + }; + + let socket = address(&postgres_process.socket_address).connect().await?; + let (client, task) = Client::new(credentials, socket, ()); + let accumulator = Rc::new(RefCell::new(FlowAccumulator::default())); + + let accumulator2 = accumulator.clone(); + LocalSet::new() + .run_until(async move { + tokio::task::spawn_local(task); + client.ready().await?; + callback(client, accumulator2.clone()).await?; + Result::<(), Box>::Ok(()) + }) + .await?; + + let mut s = String::new(); + accumulator.borrow().with_messages(|message| { + match_message!(Ok(message), Backend { + (ParameterDescription as params) => { + // OID values are not guaranteed to be stable, so we just print "..." instead. + s.push_str(&format!("ParameterDescription {:?}\n", params.param_types().into_iter().map(|_| "...").collect::>())); + }, + (RowDescription as rows) => { + s.push_str(&format!("RowDescription {}\n", rows.fields().into_iter().map(|f| f.name().to_string_lossy().into_owned()).collect::>().join(", "))); + }, + (PortalSuspended) => { + s.push_str("PortalSuspended\n"); + }, + (ErrorResponse as err) => { + for field in err.fields() { + if field.etype() as char == 'C' { + s.push_str(&format!("ErrorResponse {}\n", field.value().to_string_lossy())); + return; + } + } + s.push_str(&format!("ErrorResponse {:?}\n", err)); + }, + (NoticeResponse as notice) => { + for field in notice.fields() { + if field.ntype() as char == 'M' { + s.push_str(&format!("NoticeResponse {}\n", field.value().to_string_lossy())); + return; + } + } + s.push_str(&format!("NoticeResponse {:?}\n", notice)); + }, + (CommandComplete as cmd) => { + s.push_str(&format!("CommandComplete {:?}\n", cmd.tag())); + }, + (DataRow as row) => { + s.push_str(&format!("DataRow {}\n", 
row.values().into_iter().map(|v| v.to_string_lossy().into_owned()).collect::>().join(", "))); + }, + (CopyData as copy_data) => { + s.push_str(&format!("CopyData {:?}\n", String::from_utf8_lossy(©_data.data()))); + }, + (CopyOutResponse as copy_out) => { + s.push_str(&format!("CopyOutResponse {}\n", copy_out.format())); + }, + _unknown => { + s.push_str("Unknown\n"); + } + }) + }); + + Ok(Some(s)) +} + +#[test_log::test(tokio::test)] +async fn test_query() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client.query("SELECT 1", accumulator.clone()).await?; + Ok(()) + }) + .await? + { + assert_eq!( + s, + "RowDescription ?column?\nDataRow 1\nCommandComplete \"SELECT 1\"\n" + ); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_extended_query_success() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .parse( + Statement("test"), + "SELECT $1", + &[Oid::unspecified()], + accumulator.clone(), + ) + .describe_statement(Statement("test"), accumulator.clone()) + .bind( + Portal("test"), + Statement("test"), + &[Param::Text("1")], + &[], + accumulator.clone(), + ) + .describe_portal(Portal("test"), accumulator.clone()) + .execute( + Portal("test"), + MaxRows::Limited(NonZero::new(1).unwrap()), + accumulator.clone(), + ) + .build(), + ) + .await?; + Ok(()) + }) + .await? 
+ { + assert_eq!(s, "ParameterDescription [\"...\"]\nRowDescription ?column?\nRowDescription ?column?\nDataRow 1\nPortalSuspended\n"); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_extended_query_parse_error() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .parse(Statement("test"), ".", &[], accumulator.clone()) + .bind( + Portal("test"), + Statement("test"), + &[], + &[], + accumulator.clone(), + ) + .query("SELECT 1", accumulator.clone()) + .build(), + ) + .await?; + Ok(()) + }) + .await? + { + assert_eq!(s, "ErrorResponse 42601\n"); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_extended_query_portal_suspended() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .parse( + Statement("test"), + "SELECT generate_series(1,3)", + &[], + accumulator.clone(), + ) + .bind( + Portal("test"), + Statement("test"), + &[], + &[], + accumulator.clone(), + ) + .execute( + Portal("test"), + MaxRows::Limited(NonZero::new(2).unwrap()), + accumulator.clone(), + ) + .execute( + Portal("test"), + MaxRows::Limited(NonZero::new(2).unwrap()), + accumulator.clone(), + ) + .build(), + ) + .await?; + Ok(()) + }) + .await? + { + assert_eq!( + s, + "DataRow 1\nDataRow 2\nPortalSuspended\nDataRow 3\nCommandComplete \"SELECT 1\"\n" + ); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_extended_query_copy() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .parse( + Statement("test"), + "COPY (SELECT 1) TO STDOUT", + &[], + accumulator.clone(), + ) + .bind( + Portal("test"), + Statement("test"), + &[], + &[], + accumulator.clone(), + ) + .execute(Portal("test"), MaxRows::Unlimited, accumulator.clone()) + .build(), + ) + .await?; + Ok(()) + }) + .await? 
+ { + assert_eq!( + s, + "CopyOutResponse 0\nCopyData \"1\\n\"\nCommandComplete \"COPY 1\"\n" + ); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_extended_query_empty() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .parse(Statement("test"), "", &[], accumulator.clone()) + .bind( + Portal("test"), + Statement("test"), + &[], + &[], + accumulator.clone(), + ) + .execute(Portal("test"), MaxRows::Unlimited, accumulator.clone()) + .build(), + ) + .await?; + Ok(()) + }) + .await? + { + assert_eq!(s, ""); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_query_notice() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + // DO block with NOTICE RAISE generates a notice + client + .query( + "DO $$ BEGIN RAISE NOTICE 'test notice'; END $$;", + accumulator.clone(), + ) + .await?; + Ok(()) + }) + .await? + { + assert_eq!(s, "NoticeResponse test notice\nCommandComplete \"DO\"\n"); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_query_warning() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + // DO block with WARNING RAISE generates a warning + client + .query( + "DO $$ BEGIN RAISE WARNING 'test warning'; END $$;", + accumulator.clone(), + ) + .await?; + Ok(()) + }) + .await? + { + assert_eq!(s, "NoticeResponse test warning\nCommandComplete \"DO\"\n"); + } + Ok(()) +} + +#[test_log::test(tokio::test)] +async fn test_double_begin_transaction() -> Result<(), Box> { + if let Some(s) = with_postgres(|client, accumulator| async move { + client + .pipeline_sync( + PipelineBuilder::default() + .query("BEGIN TRANSACTION", accumulator.clone()) + .query("BEGIN TRANSACTION", accumulator.clone()) + .build(), + ) + .await?; + Ok(()) + }) + .await? 
+ { + assert_eq!(s, "CommandComplete \"BEGIN\"\nNoticeResponse there is already a transaction in progress\nCommandComplete \"BEGIN\"\n"); + } + Ok(()) +} diff --git a/rust/pgrust/tests/real_postgres.rs b/rust/pgrust/tests/real_postgres.rs index d9238e6f5d9..0d562e2047e 100644 --- a/rust/pgrust/tests/real_postgres.rs +++ b/rust/pgrust/tests/real_postgres.rs @@ -1,378 +1,22 @@ // Constants use gel_auth::AuthType; -use openssl::ssl::{Ssl, SslContext, SslMethod}; -use pgrust::connection::dsn::{Host, HostType}; use pgrust::connection::{connect_raw_ssl, ConnectionError, Credentials, ResolvedTarget}; use pgrust::errors::PgServerError; use pgrust::handshake::ConnectionSslRequirement; use rstest::rstest; -use std::io::{BufRead, BufReader, Write}; -use std::net::{Ipv4Addr, SocketAddr, TcpListener}; -use std::os::unix::fs::PermissionsExt; -use std::path::{Path, PathBuf}; -use std::process::{Command, Stdio}; -use std::sync::{Arc, RwLock}; -use std::thread; -use std::time::{Duration, Instant}; -use tempfile::TempDir; -const STARTUP_TIMEOUT_DURATION: Duration = Duration::from_secs(30); -const PORT_RELEASE_TIMEOUT: Duration = Duration::from_secs(30); -const LINGER_DURATION: Duration = Duration::from_secs(1); -const HOT_LOOP_INTERVAL: Duration = Duration::from_millis(100); -const DEFAULT_USERNAME: &str = "username"; -const DEFAULT_PASSWORD: &str = "password"; -const DEFAULT_DATABASE: &str = "postgres"; +use captive_postgres::*; -/// Represents an ephemeral port that can be allocated and released for immediate re-use by another process. -struct EphemeralPort { - port: u16, - listener: Option, -} - -impl EphemeralPort { - /// Allocates a new ephemeral port. - /// - /// Returns a Result containing the EphemeralPort if successful, - /// or an IO error if the allocation fails. 
- fn allocate() -> std::io::Result { - let socket = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; - socket.set_reuse_address(true)?; - socket.set_reuse_port(true)?; - socket.set_linger(Some(LINGER_DURATION))?; - socket.bind(&std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, 0)).into())?; - socket.listen(1)?; - let listener = TcpListener::from(socket); - let port = listener.local_addr()?.port(); - Ok(EphemeralPort { - port, - listener: Some(listener), - }) - } - - /// Consumes the EphemeralPort and returns the allocated port number. - fn take(self) -> u16 { - // Drop the listener to free up the port - drop(self.listener); - - // Loop until the port is free - let start = Instant::now(); - - // If we can successfully connect to the port, it's not fully closed - while start.elapsed() < PORT_RELEASE_TIMEOUT { - let res = std::net::TcpStream::connect((Ipv4Addr::LOCALHOST, self.port)); - if res.is_err() { - // If connection fails, the port is released - break; - } - std::thread::sleep(HOT_LOOP_INTERVAL); - } - - self.port +fn address(address: &ListenAddress) -> ResolvedTarget { + match address { + ListenAddress::Tcp(addr) => ResolvedTarget::SocketAddr(*addr), + #[cfg(unix)] + ListenAddress::Unix(path) => ResolvedTarget::UnixSocketAddr( + std::os::unix::net::SocketAddr::from_pathname(path).unwrap(), + ), } } -struct StdioReader { - output: Arc>, -} - -impl StdioReader { - fn spawn(reader: R, prefix: &'static str) -> Self { - let output = Arc::new(RwLock::new(String::new())); - let output_clone = Arc::clone(&output); - - thread::spawn(move || { - let mut buf_reader = std::io::BufReader::new(reader); - loop { - let mut line = String::new(); - match buf_reader.read_line(&mut line) { - Ok(0) => break, - Ok(_) => { - if let Ok(mut output) = output_clone.write() { - output.push_str(&line); - } - eprint!("[{}]: {}", prefix, line); - } - Err(e) => { - let error_line = format!("Error reading {}: {}\n", prefix, e); - if let Ok(mut output) = 
output_clone.write() { - output.push_str(&error_line); - } - eprintln!("{}", error_line); - } - } - } - }); - - StdioReader { output } - } - - fn contains(&self, s: &str) -> bool { - if let Ok(output) = self.output.read() { - output.contains(s) - } else { - false - } - } -} - -fn init_postgres(initdb: &Path, data_dir: &Path, auth: AuthType) -> std::io::Result<()> { - let mut pwfile = tempfile::NamedTempFile::new()?; - writeln!(pwfile, "{}", DEFAULT_PASSWORD)?; - let mut command = Command::new(initdb); - command - .arg("-D") - .arg(data_dir) - .arg("-A") - .arg(match auth { - AuthType::Deny => "reject", - AuthType::Trust => "trust", - AuthType::Plain => "password", - AuthType::Md5 => "md5", - AuthType::ScramSha256 => "scram-sha-256", - }) - .arg("--pwfile") - .arg(pwfile.path()) - .arg("-U") - .arg(DEFAULT_USERNAME); - - let output = command.output()?; - - let status = output.status; - let output_str = String::from_utf8_lossy(&output.stdout).to_string(); - let error_str = String::from_utf8_lossy(&output.stderr).to_string(); - - eprintln!("initdb stdout:\n{}", output_str); - eprintln!("initdb stderr:\n{}", error_str); - - if !status.success() { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "initdb command failed", - )); - } - - Ok(()) -} - -fn run_postgres( - postgres_bin: &Path, - data_dir: &Path, - socket_path: &Path, - ssl: Option<(PathBuf, PathBuf)>, - port: u16, -) -> std::io::Result { - let mut command = Command::new(postgres_bin); - command - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .arg("-D") - .arg(data_dir) - .arg("-k") - .arg(socket_path) - .arg("-h") - .arg(Ipv4Addr::LOCALHOST.to_string()) - .arg("-F") - // Useful for debugging - // .arg("-d") - // .arg("5") - .arg("-p") - .arg(port.to_string()); - - if let Some((cert_path, key_path)) = ssl { - let postgres_cert_path = data_dir.join("server.crt"); - let postgres_key_path = data_dir.join("server.key"); - std::fs::copy(cert_path, &postgres_cert_path)?; - std::fs::copy(key_path, 
&postgres_key_path)?; - // Set permissions for the certificate and key files - std::fs::set_permissions(&postgres_cert_path, std::fs::Permissions::from_mode(0o600))?; - std::fs::set_permissions(&postgres_key_path, std::fs::Permissions::from_mode(0o600))?; - - // Edit pg_hba.conf to change all "host" line prefixes to "hostssl" - let pg_hba_path = data_dir.join("pg_hba.conf"); - let content = std::fs::read_to_string(&pg_hba_path)?; - let modified_content = content - .lines() - .filter(|line| !line.starts_with("#") && !line.is_empty()) - .map(|line| { - if line.trim_start().starts_with("host") { - line.replacen("host", "hostssl", 1) - } else { - line.to_string() - } - }) - .collect::>() - .join("\n"); - eprintln!("pg_hba.conf:\n==========\n{modified_content}\n=========="); - std::fs::write(&pg_hba_path, modified_content)?; - - command.arg("-l"); - } - - let mut child = command.spawn()?; - - let stdout_reader = BufReader::new(child.stdout.take().expect("Failed to capture stdout")); - let _ = StdioReader::spawn(stdout_reader, "stdout"); - let stderr_reader = BufReader::new(child.stderr.take().expect("Failed to capture stderr")); - let stderr_reader = StdioReader::spawn(stderr_reader, "stderr"); - - let start_time = Instant::now(); - - let mut tcp_socket: Option = None; - let mut unix_socket: Option = None; - - let unix_socket_path = get_unix_socket_path(socket_path, port); - let tcp_socket_addr = std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, port)); - let mut db_ready = false; - - while start_time.elapsed() < STARTUP_TIMEOUT_DURATION { - std::thread::sleep(HOT_LOOP_INTERVAL); - match child.try_wait() { - Ok(Some(status)) => { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - format!("PostgreSQL exited with status: {}", status), - )) - } - Err(e) => return Err(e), - _ => {} - } - if !db_ready && stderr_reader.contains("database system is ready to accept connections") { - eprintln!("Database is ready"); - db_ready = true; - } else { - continue; - } - if 
unix_socket.is_none() { - unix_socket = std::os::unix::net::UnixStream::connect(&unix_socket_path).ok(); - } - if tcp_socket.is_none() { - tcp_socket = std::net::TcpStream::connect(tcp_socket_addr).ok(); - } - if unix_socket.is_some() && tcp_socket.is_some() { - break; - } - } - - if unix_socket.is_some() && tcp_socket.is_some() { - return Ok(child); - } - - // Print status for TCP/unix sockets - if let Some(tcp) = &tcp_socket { - eprintln!( - "TCP socket at {tcp_socket_addr:?} bound successfully on {}", - tcp.local_addr()? - ); - } else { - eprintln!("TCP socket at {tcp_socket_addr:?} binding failed"); - } - - if unix_socket.is_some() { - eprintln!("Unix socket at {unix_socket_path:?} connected successfully"); - } else { - eprintln!("Unix socket at {unix_socket_path:?} connection failed"); - } - - Err(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "PostgreSQL failed to start within 30 seconds", - )) -} - -fn test_data_dir() -> std::path::PathBuf { - Path::new("../../../tests") - .canonicalize() - .expect("Failed to canonicalize tests directory path") -} - -fn postgres_bin_dir() -> std::io::Result { - Path::new("../../../build/postgres/install/bin").canonicalize() -} - -fn get_unix_socket_path(socket_path: &Path, port: u16) -> PathBuf { - socket_path.join(format!(".s.PGSQL.{}", port)) -} - -#[derive(Debug, Clone, Copy)] -enum Mode { - Tcp, - TcpSsl, - Unix, -} - -fn create_ssl_client() -> Result> { - let ssl_context = SslContext::builder(SslMethod::tls_client())?.build(); - let mut ssl = Ssl::new(&ssl_context)?; - ssl.set_connect_state(); - Ok(ssl) -} -struct PostgresProcess { - child: std::process::Child, - socket_address: ResolvedTarget, - #[allow(unused)] - temp_dir: TempDir, -} - -impl Drop for PostgresProcess { - fn drop(&mut self) { - let _ = self.child.kill(); - } -} - -fn setup_postgres( - auth: AuthType, - mode: Mode, -) -> Result, Box> { - let Ok(bindir) = postgres_bin_dir() else { - println!("Skipping test: postgres bin dir not found"); - return 
Ok(None); - }; - - let initdb = bindir.join("initdb"); - let postgres = bindir.join("postgres"); - - if !initdb.exists() || !postgres.exists() { - println!("Skipping test: initdb or postgres not found"); - return Ok(None); - } - - let temp_dir = TempDir::new()?; - let port = EphemeralPort::allocate()?; - let data_dir = temp_dir.path().join("data"); - - init_postgres(&initdb, &data_dir, auth)?; - let ssl_key = match mode { - Mode::TcpSsl => { - let certs_dir = test_data_dir().join("certs"); - let cert = certs_dir.join("server.cert.pem"); - let key = certs_dir.join("server.key.pem"); - Some((cert, key)) - } - _ => None, - }; - - let port = port.take(); - let child = run_postgres(&postgres, &data_dir, &data_dir, ssl_key, port)?; - - let socket_address = match mode { - Mode::Unix => ResolvedTarget::to_addrs_sync(&Host( - HostType::Path(data_dir.to_string_lossy().to_string()), - port, - ))? - .remove(0), - Mode::Tcp | Mode::TcpSsl => { - ResolvedTarget::SocketAddr(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port)) - } - }; - - Ok(Some(PostgresProcess { - child, - socket_address, - temp_dir, - })) -} - #[rstest] #[tokio::test] async fn test_auth_real( @@ -391,7 +35,7 @@ async fn test_auth_real( server_settings: Default::default(), }; - let client = postgres_process.socket_address.connect().await?; + let client = address(&postgres_process.socket_address).connect().await?; let ssl_requirement = match mode { Mode::TcpSsl => ConnectionSslRequirement::Required, @@ -426,7 +70,7 @@ async fn test_bad_password( server_settings: Default::default(), }; - let client = postgres_process.socket_address.connect().await?; + let client = address(&postgres_process.socket_address).connect().await?; let ssl_requirement = match mode { Mode::TcpSsl => ConnectionSslRequirement::Required, @@ -458,7 +102,7 @@ async fn test_bad_username( server_settings: Default::default(), }; - let client = postgres_process.socket_address.connect().await?; + let client = 
address(&postgres_process.socket_address).connect().await?; let ssl_requirement = match mode { Mode::TcpSsl => ConnectionSslRequirement::Required, @@ -490,7 +134,7 @@ async fn test_bad_database( server_settings: Default::default(), }; - let client = postgres_process.socket_address.connect().await?; + let client = address(&postgres_process.socket_address).connect().await?; let ssl_requirement = match mode { Mode::TcpSsl => ConnectionSslRequirement::Required, From 0bbb4e6601eb7a40f2a2cadfc16c68c6c6ec30c5 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Thu, 2 Jan 2025 16:14:19 -0800 Subject: [PATCH 2/7] Fix pg_constraint oid duplication (#8166) Multiple of our invented constraints were being given the same oid which caused problems. One specific problem was generating bogus pgdumps, which annoyingly mostly only showed up in inplace-upgrade tests: https://github.com/edgedb/edgedb/actions/runs/12442350513/job/34740389082?pr=8159 The problem was that pg_get_constraintdef was returning a constraint definition for the "wrong" object; the defining query would return 3 rows, and postgres silently returns the first. Fix this by adding some bits to the oid separate from the uuid of the link, and test that the fix works by putting the body of pg_get_constraintdef into a subquery. This hopefully will unblock #8159. --- edb/buildmeta.py | 2 +- edb/pgsql/metaschema.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/edb/buildmeta.py b/edb/buildmeta.py index 31a13a15ded..dd944f10e9e 100644 --- a/edb/buildmeta.py +++ b/edb/buildmeta.py @@ -60,7 +60,7 @@ # The merge conflict there is a nice reminder that you probably need # to write a patch in edb/pgsql/patches.py, and then you should preserve # the old value. 
-EDGEDB_CATALOG_VERSION = 2024_12_17_00_00 +EDGEDB_CATALOG_VERSION = 2025_01_02_00_00 EDGEDB_MAJOR_VERSION = 7 diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 683e1d3809d..8e323b420be 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -6389,12 +6389,14 @@ def make_wrapper_view(name: str) -> trampoline.VersionedView: name=('edgedbsql', 'uuid_to_oid'), args=( ('id', 'uuid'), + # extra is two extra bits to throw into the oid, for now + ('extra', 'int4', '0'), ), returns=('oid',), volatility='immutable', text=""" SELECT ( - ('x' || substring(id::text, 2, 7))::bit(28)::bigint + ('x' || substring(id::text, 2, 7))::bit(28)::bigint*4 + extra + 40000)::oid; """ ) @@ -7194,7 +7196,9 @@ def make_wrapper_view(name: str) -> trampoline.VersionedView: -- foreign keys for object tables SELECT - edgedbsql_VER.uuid_to_oid(sl.id) as oid, + -- uuid_to_oid needs "extra" arg to disambiguate from the link table + -- keys below + edgedbsql_VER.uuid_to_oid(sl.id, 0) as oid, vt.table_name || '_fk_' || sl.name AS conname, edgedbsql_VER.uuid_to_oid(vt.module_id) AS connamespace, 'f'::"char" AS contype, @@ -7240,7 +7244,9 @@ def make_wrapper_view(name: str) -> trampoline.VersionedView: -- - single link with link properties (source & target), -- these constraints do not actually exist, so we emulate it entierly SELECT - edgedbsql_VER.uuid_to_oid(sp.id) AS oid, + -- uuid_to_oid needs "extra" arg to disambiguate from other + -- constraints using this pointer + edgedbsql_VER.uuid_to_oid(sp.id, spec.attnum) AS oid, vt.table_name || '_fk_' || spec.name AS conname, edgedbsql_VER.uuid_to_oid(vt.module_id) AS connamespace, 'f'::"char" AS contype, @@ -7856,6 +7862,10 @@ def construct_pg_view( returns=('text',), volatility='stable', text=r""" + -- Wrap in a subquery SELECT so that we get a clear failure + -- if something is broken and this returns multiple rows. + -- (By default it would silently return the first.) 
+ SELECT ( SELECT CASE WHEN contype = 'p' THEN 'PRIMARY KEY(' || ( @@ -7868,7 +7878,6 @@ def construct_pg_view( SELECT attname FROM edgedbsql_VER.pg_attribute WHERE attrelid = conrelid AND attnum = ANY(conkey) - LIMIT 1 ) || '")' || ' REFERENCES "' || pn.nspname || '"."' || pc.relname || '"(id)' ELSE '' @@ -7878,6 +7887,7 @@ def construct_pg_view( LEFT JOIN edgedbsql_VER.pg_namespace pn ON pc.relnamespace = pn.oid WHERE con.oid = conid + ) """ ), trampoline.VersionedFunction( From 4472152ceb1866444f3f1f17b26842ac664b9ecf Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Thu, 2 Jan 2025 19:15:22 -0800 Subject: [PATCH 3/7] Fix SQL introspection after inplace upgrade (#8159) We need to refresh the views after the upgrade. Fixes #8155. --- edb/server/cluster.py | 14 ++++++++++---- edb/server/inplace_upgrade.py | 10 ++++++++++ tests/inplace-testing/test.sh | 3 +++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/edb/server/cluster.py b/edb/server/cluster.py index 6b9a55e1a31..6f4524bb06f 100644 --- a/edb/server/cluster.py +++ b/edb/server/cluster.py @@ -324,7 +324,11 @@ async def test() -> None: started = time.monotonic() await test() left -= (time.monotonic() - started) - if res := self._admin_query("SELECT ();", f"{max(1, int(left))}s"): + if res := self._admin_query( + "SELECT ();", + f"{max(1, int(left))}s", + check=False, + ): raise ClusterError( f'could not connect to edgedb-server ' f'within {timeout} seconds (exit code = {res})') from None @@ -333,6 +337,7 @@ def _admin_query( self, query: str, wait_until_available: str = "0s", + check: bool=True, ) -> int: args = [ "edgedb", @@ -350,12 +355,13 @@ def _admin_query( wait_until_available, query, ] - res = subprocess.call( + res = subprocess.run( args=args, - stdout=subprocess.DEVNULL, + check=check, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) - return res + return res.returncode async def set_test_config(self) -> None: self._admin_query(f''' diff --git 
a/edb/server/inplace_upgrade.py b/edb/server/inplace_upgrade.py index d07ca92bee6..86b9c4551b2 100644 --- a/edb/server/inplace_upgrade.py +++ b/edb/server/inplace_upgrade.py @@ -57,6 +57,7 @@ from edb.pgsql import common as pg_common from edb.pgsql import dbops +from edb.pgsql import metaschema from edb.pgsql import trampoline @@ -273,6 +274,15 @@ async def _upgrade_one( except Exception: raise + # Refresh the pg_catalog materialized views + current_block = dbops.PLTopBlock() + refresh = metaschema.generate_sql_information_schema_refresh( + backend_params.instance_params.version + ) + refresh.generate(current_block) + patch = current_block.to_string() + await ctx.conn.sql_execute(patch.encode('utf-8')) + new_local_spec = config.load_spec_from_schema( schema, only_exts=True, diff --git a/tests/inplace-testing/test.sh b/tests/inplace-testing/test.sh index f7265fc1e20..d3ac8012738 100755 --- a/tests/inplace-testing/test.sh +++ b/tests/inplace-testing/test.sh @@ -151,6 +151,9 @@ if $EDGEDB query 'create empty branch asdf'; then fi $EDGEDB query 'configure instance reset force_database_error' stop_server +if [ "$SAVE_TARBALLS" = 1 ]; then + tar cf "$DIR"-cooked2.tar "$DIR" +fi # Test! 
From 2349a5ea8dcacfed7497cf39f3a9c5b1b81c5b76 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Sat, 4 Jan 2025 00:21:22 -0500 Subject: [PATCH 4/7] test: use new server for stats tests (#8172) So that reset_query_stats() doesn't affect each other --- tests/test_edgeql_sys.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/tests/test_edgeql_sys.py b/tests/test_edgeql_sys.py index 5909661e757..319fe92d7ae 100644 --- a/tests/test_edgeql_sys.py +++ b/tests/test_edgeql_sys.py @@ -38,12 +38,13 @@ async def _configure_track(self, option: str): async def _bad_query_for_stats(self): raise NotImplementedError - async def _test_sys_query_stats(self): + def _before_test_sys_query_stats(self): if self.backend_dsn: self.skipTest( "can't run query stats test when extension isn't present" ) + async def _test_sys_query_stats(self): stats_query = f''' with stats := ( select @@ -177,7 +178,15 @@ async def _bad_query_for_stats(self): await self.con.query(f'select {self.stats_magic_word}_NoSuchType') async def test_edgeql_sys_query_stats(self): - await self._test_sys_query_stats() + self._before_test_sys_query_stats() + async with tb.start_edgedb_server() as sd: + old_con = self.con + self.con = await sd.connect() + try: + await self._test_sys_query_stats() + finally: + await self.con.aclose() + self.con = old_con class TestSQLSys(tb.SQLQueryTestCase, TestQueryStatsMixin): @@ -215,4 +224,14 @@ async def _bad_query_for_stats(self): ) async def test_sql_sys_query_stats(self): - await self._test_sys_query_stats() + self._before_test_sys_query_stats() + async with tb.start_edgedb_server() as sd: + old_cons = self.con, self.scon + self.con = await sd.connect() + self.scon = await sd.connect_pg() + try: + await self._test_sys_query_stats() + finally: + await self.scon.close() + await self.con.aclose() + self.con, self.scon = old_cons From c25096b2a60db3c95266f3396179be90659b487f Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Mon, 6 Jan 2025 18:34:45 +0100 Subject: [PATCH 5/7] Expand SQL docs (#8083) Co-authored-by: Scott Trinh --- docs/reference/sql_adapter.rst | 207 ++++++++++++++++++++++++++++----- 1 file changed, 180 insertions(+), 27 deletions(-) diff --git a/docs/reference/sql_adapter.rst b/docs/reference/sql_adapter.rst index a10960a466e..2cecdf8fa02 100644 --- a/docs/reference/sql_adapter.rst +++ b/docs/reference/sql_adapter.rst @@ -300,54 +300,207 @@ construct is mapped to PostgreSQL schema: - Aliases are not mapped to PostgreSQL schema. -- Globals are mapped to connection settings, prefixed with ``global``. - For example, a ``global default::username: str`` can be set using: - - .. code-block:: sql +.. versionadded:: 6.0 - SET "global default::username" TO 'Tom'``. + - Globals are mapped to connection settings, prefixed with ``global``. + For example, a ``global default::username: str`` can be accessed using: -- Access policies are applied to object type tables when setting - ``apply_access_policies_pg`` is set to ``true``. + .. code-block:: sql -- Mutation rewrites and triggers are applied to all DML commands. + SET "global default::username" TO 'Tom'; + SHOW "global default::username"; + - Access policies are applied to object type tables when setting + ``apply_access_policies_pg`` is set to ``true``. + + - Mutation rewrites and triggers are applied to all DML commands. DML commands ============ -When using ``INSERT``, ``DELETE`` or ``UPDATE`` on any table, mutation rewrites -and triggers are applied. These commands do not have a straight-forward -translation to EdgeQL DML commands, but instead use the following mapping: +.. versionchanged:: _default + + Data Modification Language commands (``INSERT``, ``UPDATE``, ``DELETE``, ..) + are not supported in EdgeDB <6.0. + +.. versionchanged:: 6.0 + +.. 
versionadded:: 6.0 -- ``INSERT INTO "Foo"`` object table maps to ``insert Foo``, + When using ``INSERT``, ``DELETE`` or ``UPDATE`` on any table, mutation + rewrites and triggers are applied. These commands do not have a + straight-forward translation to EdgeQL DML commands, but instead use the + following mapping: -- ``INSERT INTO "Foo.keywords"`` link/property table maps to an - ``update Foo { keywords += ... }``, + - ``INSERT INTO "Foo"`` object table maps to ``insert Foo``, -- ``DELETE FROM "Foo"`` object table maps to ``delete Foo``, + - ``INSERT INTO "Foo.keywords"`` link/property table maps to an + ``update Foo { keywords += ... }``, -- ``DELETE FROM "Foo.keywords"`` link property/table maps to - ``update Foo { keywords -= ... }``, + - ``DELETE FROM "Foo"`` object table maps to ``delete Foo``, -- ``UPDATE "Foo"`` object table maps to ``update Foo set { ... }``, + - ``DELETE FROM "Foo.keywords"`` link property/table maps to + ``update Foo { keywords -= ... }``, -- ``UPDATE "Foo.keywords"`` is not supported. + - ``UPDATE "Foo"`` object table maps to ``update Foo set { ... }``, + + - ``UPDATE "Foo.keywords"`` is not supported. Connection settings =================== -SQL adapter supports a limited subset of PostgreSQL connection settings. -There are the following additionally connection settings: +SQL adapter supports most of PostgreSQL connection settings +(for example ``search_path``), in the same manner as plain PostgreSQL: + +.. code-block:: sql + + SET search_path TO my_module; + + SHOW search_path; + + RESET search_path; + +.. versionadded:: 6.0 + + In addition, there are the following EdgeDB-specific settings: + + - settings prefixed with ``"global "`` set the values of globals. + + Because SQL syntax allows only string, integer and float constants in + ``SET`` command, globals of other types such as ``datetime`` cannot be set + this way. + + .. 
code-block:: sql + + SET "global my_module::hello" TO 'world'; + + Special handling is in place to enable setting: + - ``bool`` types via integers (0 or 1), + - ``uuid`` types via hex-encoded strings. + + .. code-block:: sql + + SET "global my_module::current_user_id" + TO '592c62c6-73dd-4b7b-87ba-46e6d34ec171'; + SET "global my_module::is_admin" TO 1; + + To set globals of other types via SQL, it is recommended to change the + global to use one of the simple types instead, and use appropriate casts + where the global is used. + + + - ``allow_user_specified_id`` (default ``false``), + + - ``apply_access_policies_pg`` (default ``false``), + + Note that if ``allow_user_specified_id`` or ``apply_access_policies_pg`` are + unset, they default to configuration set by ``configure current database`` + EdgeQL command. + + +Introspection +============= + +The adapter emulates introspection schemas of PostgreSQL: ``information_schema`` +and ``pg_catalog``. + +Both schemas are not perfectly emulated, since they are quite large and +complicated stores of information, that also changed between versions of +PostgreSQL. + +Because of that, some tools might show objects that are not queryable or might +report problems when introspecting. In such cases, please report the problem on +GitHub so we can track the incompatibility down. + +Note that since the two information schemas are emulated, querying them may +perform worse compared to other tables in the database. As a result, tools like +``pg_dump`` and other introspection utilities might seem slower. + + +Locking +======= + +.. versionchanged:: _default + + SQL adapter does not support ``LOCK`` in EdgeDB <6.0. + +.. versionchanged:: 6.0 + +.. 
versionadded:: 6.0 + + SQL adapter supports LOCK command with the following limitations: + + - it cannot be used on tables that represent object types with access + properties or links of such objects, + - it cannot be used on tables that represent object types that have child + types extending them. + +Query cache +=========== + +An SQL query is issued to EdgeDB, it is compiled to an internal SQL query, which +is then issued to the backing PostgreSQL instance. The compiled query is then +cached, so each following issue of the same query will not perform any +compilation, but just pass through the cached query. + +.. versionadded:: 6.0 + + Additionally, most queries are "normalized" before compilation. This process + extracts constant values and replaces them by internal query parameters. + This allows sharing of compilation cache between queries that differ in + only constant values. This process is totally opaque and is fully handled by + EdgeDB. For example: + + .. code-block:: sql + + SELECT $1, 42; + + ... is normalized to: + + .. code-block:: sql + + SELECT $1, $2; + + This way, when a similar query is issued to EdgeDB: + + .. code-block:: sql + + SELECT $1, 500; + + ... it normalizes to the same query as before, so it can reuse the query + cache. + + Note that normalization process does not (yet) remove any whitespace, so + queries ``SELECT 1;`` and ``SELECT 1 ;`` are compiled separately. + + +Known limitations +================= + +Following SQL statements are not supported: + +- ``CREATE``, ``ALTER``, ``DROP``, + +- ``TRUNCATE``, ``COMMENT``, ``SECURITY LABEL``, ``IMPORT FOREIGN SCHEMA``, + +- ``GRANT``, ``REVOKE``, + +- ``OPEN``, ``FETCH``, ``MOVE``, ``CLOSE``, ``DECLARE``, ``RETURN``, + +- ``CHECKPOINT``, ``DISCARD``, ``CALL``, + +- ``REINDEX``, ``VACUUM``, ``CLUSTER``, ``REFRESH MATERIALIZED VIEW``, + +- ``LISTEN``, ``UNLISTEN``, ``NOTIFY``, + +- ``LOAD``. 
-- ``allow_user_specified_id`` (default ``false``), -- ``apply_access_policies_pg`` (default ``false``), -- settings prefixed with ``"global "`` can use used to set values of globals. +Following functions are not supported: -Note that if ``allow_user_specified_id`` or ``apply_access_policies_pg`` are -unset, they default to configuration set by ``configure current database`` -EdgeQL command. +- ``set_config``, +- ``pg_filenode_relation``, +- most of system administration functions. Example: gradual transition from ORMs to EdgeDB From 1abb068f22e15200f6fed0e7ad5a7db79e6001a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Mon, 6 Jan 2025 20:28:22 +0100 Subject: [PATCH 6/7] Fix SQL settings_to_str for enums and bools (#8174) Closes #8147 --- edb/server/pgcon/pgcon.pxd | 2 +- edb/server/pgcon/pgcon.pyx | 182 +++++++++++++++++++++++++++++---- edb/server/protocol/pg_ext.pyx | 2 +- tests/test_sql_query.py | 51 +++++++++ 4 files changed, 217 insertions(+), 20 deletions(-) diff --git a/edb/server/pgcon/pgcon.pxd b/edb/server/pgcon/pgcon.pxd index 93da530be39..bc7883ca340 100644 --- a/edb/server/pgcon/pgcon.pxd +++ b/edb/server/pgcon/pgcon.pxd @@ -190,4 +190,4 @@ cdef class PGConnection: cdef inline str get_tenant_label(self) cpdef set_stmt_cache_size(self, int maxsize) -cdef setting_to_sql(setting) +cdef setting_to_sql(name, setting) diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 241e7be2b67..43201ab7ed9 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -1824,7 +1824,9 @@ cdef class PGConnection: msg_buf = WriteBuffer.new_message(b'D') msg_buf.write_int16(1) # number of column values setting = dbv.current_fe_settings()[setting_name] - msg_buf.write_len_prefixed_utf8(setting_to_sql(setting)) + msg_buf.write_len_prefixed_utf8( + setting_to_sql(setting_name, setting) + ) buf.write_buffer(msg_buf.end_message()) # CommandComplete @@ -3015,28 +3017,172 @@ cdef bytes FLUSH_MESSAGE = 
bytes(WriteBuffer.new_message(b'H').end_message()) cdef EdegDBCodecContext DEFAULT_CODEC_CONTEXT = EdegDBCodecContext() +# Settings that are enums or bools and should not be quoted. +# Can be retrived from PostgreSQL with: +# SELECt name FROM pg_catalog.pg_settings WHERE vartype IN ('enum', 'bool'); +cdef set ENUM_SETTINGS = { + 'allow_alter_system', + 'allow_in_place_tablespaces', + 'allow_system_table_mods', + 'archive_mode', + 'array_nulls', + 'autovacuum', + 'backslash_quote', + 'bytea_output', + 'check_function_bodies', + 'client_min_messages', + 'compute_query_id', + 'constraint_exclusion', + 'data_checksums', + 'data_sync_retry', + 'debug_assertions', + 'debug_logical_replication_streaming', + 'debug_parallel_query', + 'debug_pretty_print', + 'debug_print_parse', + 'debug_print_plan', + 'debug_print_rewritten', + 'default_toast_compression', + 'default_transaction_deferrable', + 'default_transaction_isolation', + 'default_transaction_read_only', + 'dynamic_shared_memory_type', + 'edb_stat_statements.save', + 'edb_stat_statements.track', + 'edb_stat_statements.track_planning', + 'edb_stat_statements.track_utility', + 'enable_async_append', + 'enable_bitmapscan', + 'enable_gathermerge', + 'enable_group_by_reordering', + 'enable_hashagg', + 'enable_hashjoin', + 'enable_incremental_sort', + 'enable_indexonlyscan', + 'enable_indexscan', + 'enable_material', + 'enable_memoize', + 'enable_mergejoin', + 'enable_nestloop', + 'enable_parallel_append', + 'enable_parallel_hash', + 'enable_partition_pruning', + 'enable_partitionwise_aggregate', + 'enable_partitionwise_join', + 'enable_presorted_aggregate', + 'enable_seqscan', + 'enable_sort', + 'enable_tidscan', + 'escape_string_warning', + 'event_triggers', + 'exit_on_error', + 'fsync', + 'full_page_writes', + 'geqo', + 'gss_accept_delegation', + 'hot_standby', + 'hot_standby_feedback', + 'huge_pages', + 'huge_pages_status', + 'icu_validation_level', + 'ignore_checksum_failure', + 'ignore_invalid_pages', + 
'ignore_system_indexes', + 'in_hot_standby', + 'integer_datetimes', + 'intervalstyle', + 'jit', + 'jit_debugging_support', + 'jit_dump_bitcode', + 'jit_expressions', + 'jit_profiling_support', + 'jit_tuple_deforming', + 'krb_caseins_users', + 'lo_compat_privileges', + 'log_checkpoints', + 'log_connections', + 'log_disconnections', + 'log_duration', + 'log_error_verbosity', + 'log_executor_stats', + 'log_hostname', + 'log_lock_waits', + 'log_min_error_statement', + 'log_min_messages', + 'log_parser_stats', + 'log_planner_stats', + 'log_recovery_conflict_waits', + 'log_replication_commands', + 'log_statement', + 'log_statement_stats', + 'log_truncate_on_rotation', + 'logging_collector', + 'parallel_leader_participation', + 'password_encryption', + 'plan_cache_mode', + 'quote_all_identifiers', + 'recovery_init_sync_method', + 'recovery_prefetch', + 'recovery_target_action', + 'recovery_target_inclusive', + 'remove_temp_files_after_crash', + 'restart_after_crash', + 'row_security', + 'send_abort_for_crash', + 'send_abort_for_kill', + 'session_replication_role', + 'shared_memory_type', + 'ssl', + 'ssl_max_protocol_version', + 'ssl_min_protocol_version', + 'ssl_passphrase_command_supports_reload', + 'ssl_prefer_server_ciphers', + 'standard_conforming_strings', + 'stats_fetch_consistency', + 'summarize_wal', + 'sync_replication_slots', + 'synchronize_seqscans', + 'synchronous_commit', + 'syslog_facility', + 'syslog_sequence_numbers', + 'syslog_split_messages', + 'trace_connection_negotiation', + 'trace_notify', + 'trace_sort', + 'track_activities', + 'track_commit_timestamp', + 'track_counts', + 'track_functions', + 'track_io_timing', + 'track_wal_io_timing', + 'transaction_deferrable', + 'transaction_isolation', + 'transaction_read_only', + 'transform_null_equals', + 'update_process_title', + 'wal_compression', + 'wal_init_zero', + 'wal_level', + 'wal_log_hints', + 'wal_receiver_create_temp_slot', + 'wal_recycle', + 'wal_sync_method', + 'xmlbinary', + 'xmloption', + 
'zero_damaged_pages', +} + + +cdef setting_to_sql(name, setting): + is_enum = name.lower() in ENUM_SETTINGS -cdef setting_to_sql(setting): assert typeutils.is_container(setting) - return ', '.join(setting_val_to_sql(v) for v in setting) - - -cdef set NON_QUOTABLE_STRINGS = { - 'repeatable read', - 'read committed', - 'read uncommitted', - 'off', - 'on', - 'yes', - 'no', - 'true', - 'false', -} + return ', '.join(setting_val_to_sql(v, is_enum) for v in setting) -cdef inline str setting_val_to_sql(val: str | int | float): +cdef inline str setting_val_to_sql(val: str | int | float, is_enum: bool): if isinstance(val, str): - if val in NON_QUOTABLE_STRINGS: + if is_enum: # special case: no quoting return val # quote as identifier diff --git a/edb/server/protocol/pg_ext.pyx b/edb/server/protocol/pg_ext.pyx index cf5a040359b..59e699c36fd 100644 --- a/edb/server/protocol/pg_ext.pyx +++ b/edb/server/protocol/pg_ext.pyx @@ -355,7 +355,7 @@ cdef class ConnectionView: return self._session_state_db_cache[1] rv = json.dumps({ - key: setting_to_sql(val) for key, val in self._settings.items() + key: setting_to_sql(key, val) for key, val in self._settings.items() }).encode("utf-8") self._session_state_db_cache = (self._settings, rv) return rv diff --git a/tests/test_sql_query.py b/tests/test_sql_query.py index 8bf94846b72..bf799d40f5e 100644 --- a/tests/test_sql_query.py +++ b/tests/test_sql_query.py @@ -3015,6 +3015,57 @@ async def are_policies_applied() -> bool: # setting cleanup not needed, since with end with the None, None + async def test_sql_query_set_05(self): + # IntervalStyle + + await self.scon.execute('SET IntervalStyle TO ISO_8601;') + [[res]] = await self.squery_values( + "SELECT '2 years 15 months 100 weeks 99 hours'::interval::text;" + ) + self.assertEqual(res, 'P3Y3M700DT99H') + + await self.scon.execute('SET IntervalStyle TO postgres_verbose;') + [[res]] = await self.squery_values( + "SELECT '2 years 15 months 100 weeks 99 hours'::interval::text;" + ) + 
self.assertEqual(res, '@ 3 years 3 mons 700 days 99 hours') + + await self.scon.execute('SET IntervalStyle TO sql_standard;') + [[res]] = await self.squery_values( + "SELECT '2 years 15 months 100 weeks 99 hours'::interval::text;" + ) + self.assertEqual(res, '+3-3 +700 +99:00:00') + + async def test_sql_query_set_06(self): + # bytea_output + + await self.scon.execute('SET bytea_output TO hex') + [[res]] = await self.squery_values( + "SELECT '\\x01abcdef01'::bytea::text" + ) + self.assertEqual(res, r'\x01abcdef01') + + await self.scon.execute('SET bytea_output TO escape') + [[res]] = await self.squery_values( + "SELECT '\\x01abcdef01'::bytea::text" + ) + self.assertEqual(res, r'\001\253\315\357\001') + + async def test_sql_query_set_07(self): + # enable_memoize + + await self.scon.execute('SET enable_memoize TO ye') + [[res]] = await self.squery_values( + "SELECT 'hello'" + ) + self.assertEqual(res, 'hello') + + await self.scon.execute('SET enable_memoize TO off') + [[res]] = await self.squery_values( + "SELECT 'hello'" + ) + self.assertEqual(res, 'hello') + @test.skip( 'blocking the connection causes other tests which trigger a ' 'PostgreSQL error to encounter a InternalServerError and close ' From c4e5e4ce55898b1ce080fc0aa07224aefd2e48d7 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Mon, 6 Jan 2025 13:09:14 -0700 Subject: [PATCH 7/7] Split protocol generation into a new `db_proto` crate (#8165) This attempts to clean up the automatic codegen for the database protocols by moving it into a new `db_proto` crate. Note that this comes with some additional complexity -- given Rust's orphan impl (https://github.com/Ixrec/rust-orphan-rules) rules, some of our tricks to ensure constant evaluation for buffers become a lot more complex. This is probably the furthest we can take a macro-based approach without having to dive into proc_macros, but the code is somewhat cleaner after this refactoring. 
We also now support Array<> with u8, u16, i16, u32, i32 lengths, as well as ZTArray<> for all basic types. All basic types also support fixed-sized arrays. --- Cargo.lock | 12 + Cargo.toml | 2 + rust/auth/src/scram.rs | 2 +- rust/db_proto/Cargo.toml | 18 + rust/db_proto/README.md | 4 + .../src/protocol => db_proto/src}/arrays.rs | 245 +++--- .../src/protocol => db_proto/src}/buffer.rs | 5 +- rust/db_proto/src/datatypes.rs | 558 +++++++++++++ rust/db_proto/src/field_access.rs | 569 +++++++++++++ .../src/protocol => db_proto/src}/gen.rs | 236 +++--- rust/db_proto/src/lib.rs | 218 +++++ .../src}/message_group.rs | 31 +- rust/db_proto/src/test_protocol.rs | 140 ++++ .../src/protocol => db_proto/src}/writer.rs | 0 rust/pgrust/Cargo.toml | 2 +- rust/pgrust/src/connection/conn.rs | 12 +- rust/pgrust/src/connection/flow.rs | 21 +- rust/pgrust/src/connection/mod.rs | 7 +- rust/pgrust/src/connection/raw_conn.rs | 3 +- rust/pgrust/src/errors/mod.rs | 2 +- .../src/handshake/client_state_machine.rs | 3 +- rust/pgrust/src/handshake/edgedb_server.rs | 6 +- .../src/handshake/server_state_machine.rs | 7 +- rust/pgrust/src/protocol/datatypes.rs | 788 ------------------ rust/pgrust/src/protocol/definition.rs | 740 ---------------- rust/pgrust/src/protocol/edgedb.rs | 4 +- rust/pgrust/src/protocol/mod.rs | 160 +--- rust/pgrust/src/protocol/postgres.rs | 3 +- rust/pgrust/src/python.rs | 15 +- rust/pgrust/tests/query_real_postgres.rs | 2 +- 30 files changed, 1828 insertions(+), 1987 deletions(-) create mode 100644 rust/db_proto/Cargo.toml create mode 100644 rust/db_proto/README.md rename rust/{pgrust/src/protocol => db_proto/src}/arrays.rs (53%) rename rust/{pgrust/src/protocol => db_proto/src}/buffer.rs (98%) create mode 100644 rust/db_proto/src/datatypes.rs create mode 100644 rust/db_proto/src/field_access.rs rename rust/{pgrust/src/protocol => db_proto/src}/gen.rs (71%) create mode 100644 rust/db_proto/src/lib.rs rename rust/{pgrust/src/protocol => db_proto/src}/message_group.rs 
(85%) create mode 100644 rust/db_proto/src/test_protocol.rs rename rust/{pgrust/src/protocol => db_proto/src}/writer.rs (100%) delete mode 100644 rust/pgrust/src/protocol/datatypes.rs delete mode 100644 rust/pgrust/src/protocol/definition.rs diff --git a/Cargo.lock b/Cargo.lock index d550dc2d969..081ff2ee5cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -517,6 +517,17 @@ dependencies = [ "typenum", ] +[[package]] +name = "db_proto" +version = "0.1.0" +dependencies = [ + "derive_more", + "paste", + "pretty_assertions", + "thiserror 1.0.63", + "uuid", +] + [[package]] name = "derive_more" version = "1.0.0" @@ -1588,6 +1599,7 @@ dependencies = [ "captive_postgres", "clap", "clap_derive", + "db_proto", "derive_more", "futures", "gel_auth", diff --git a/Cargo.toml b/Cargo.toml index 26ae534004e..50a653d221e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "rust/auth", "rust/captive_postgres", "rust/conn_pool", + "rust/db_proto", "rust/pgrust", "rust/http", "rust/pyo3_util" @@ -21,6 +22,7 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter"] } gel_auth = { path = "rust/auth" } +db_proto = { path = "rust/db_proto" } captive_postgres = { path = "rust/captive_postgres" } conn_pool = { path = "rust/conn_pool" } pgrust = { path = "rust/pgrust" } diff --git a/rust/auth/src/scram.rs b/rust/auth/src/scram.rs index bcf055d9aae..8c1e8ace725 100644 --- a/rust/auth/src/scram.rs +++ b/rust/auth/src/scram.rs @@ -6,7 +6,7 @@ //! protocols like Postgres and SASL to enhance security against common attacks //! such as replay and man-in-the-middle attacks. //! -//! https://en.wikipedia.org/wiki/Salted_Challenge_Response_Authentication_Mechanism +//! //! //! ## Limitations of this implementation //! 
diff --git a/rust/db_proto/Cargo.toml b/rust/db_proto/Cargo.toml new file mode 100644 index 00000000000..80d8bca57f0 --- /dev/null +++ b/rust/db_proto/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "db_proto" +version = "0.1.0" +license = "MIT/Apache-2.0" +authors = ["MagicStack Inc. "] +edition = "2021" + +[lints] +workspace = true + +[dependencies] +thiserror = "1" +paste = "1" +derive_more = { version = "1", features = ["full"] } +uuid = "1" + +[dev-dependencies] +pretty_assertions = "1.2.0" diff --git a/rust/db_proto/README.md b/rust/db_proto/README.md new file mode 100644 index 00000000000..aae37968447 --- /dev/null +++ b/rust/db_proto/README.md @@ -0,0 +1,4 @@ +# db_proto + +This is a crate that makes parsing and serializing of PostgreSQL-like protocols +(ie: Postgres itself, as well as Gel/EdgeDB) easier. diff --git a/rust/pgrust/src/protocol/arrays.rs b/rust/db_proto/src/arrays.rs similarity index 53% rename from rust/pgrust/src/protocol/arrays.rs rename to rust/db_proto/src/arrays.rs index 9e7811ca049..c65a5118f82 100644 --- a/rust/pgrust/src/protocol/arrays.rs +++ b/rust/db_proto/src/arrays.rs @@ -1,9 +1,11 @@ #![allow(private_bounds)] -use super::{Enliven, FieldAccessArray, FixedSize, Meta, MetaRelation}; + +use super::{Enliven, FieldAccessArray, FixedSize, Meta, MetaRelation, ParseError}; pub use std::marker::PhantomData; pub mod meta { pub use super::ArrayMeta as Array; + pub use super::FixedArrayMeta as FixedArray; pub use super::ZTArrayMeta as ZTArray; } @@ -15,7 +17,7 @@ pub struct ZTArray<'a, T: FieldAccessArray> { /// Metaclass for [`ZTArray`]. 
pub struct ZTArrayMeta { - pub(crate) _phantom: PhantomData, + pub _phantom: PhantomData, } impl Meta for ZTArrayMeta { @@ -95,6 +97,38 @@ impl<'a, T: FieldAccessArray> Iterator for ZTArrayIter<'a, T> { } } +impl FieldAccessArray for ZTArrayMeta { + const META: &'static dyn Meta = &ZTArrayMeta:: { + _phantom: PhantomData, + }; + fn size_of_field_at(mut buf: &[u8]) -> Result { + let mut size = 1; + loop { + if buf.is_empty() { + return Err(ParseError::TooShort); + } + if buf[0] == 0 { + return Ok(size); + } + let elem_size = match T::size_of_field_at(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + buf = buf.split_at(elem_size).1; + size += elem_size; + } + } + fn extract(buf: &[u8]) -> Result, ParseError> { + Ok(ZTArray::new(buf)) + } + fn copy_to_buf(buf: &mut crate::BufWriter, value: &&[::ForBuilder<'_>]) { + for elem in *value { + T::copy_to_buf(buf, elem); + } + buf.write_u8(0); + } +} + /// Inflated version of a length-specified array with zero-copy iterator access. pub struct Array<'a, L, T: FieldAccessArray> { _phantom: PhantomData<(L, T)>, @@ -104,7 +138,7 @@ pub struct Array<'a, L, T: FieldAccessArray> { /// Metaclass for [`Array`]. pub struct ArrayMeta { - pub(crate) _phantom: PhantomData<(L, T)>, + pub _phantom: PhantomData<(L, T)>, } impl Meta for ArrayMeta { @@ -121,7 +155,7 @@ impl Meta for ArrayMeta { impl Enliven for ArrayMeta where - T: FieldAccessArray, + T: FieldAccessArray + Enliven, { type WithLifetime<'a> = Array<'a, L, T>; type ForMeasure<'a> = &'a [::ForMeasure<'a>]; @@ -203,130 +237,6 @@ impl<'a, T: FieldAccessArray> Iterator for ArrayIter<'a, T> { } } -/// Definate array accesses for inflated, strongly-typed arrays of both -/// zero-terminated and length-delimited types. -macro_rules! 
array_access { - ($ty:ty) => { - $crate::protocol::arrays::array_access!($ty | u8 i16 i32); - }; - ($ty:ty | $($len:ty)*) => { - $( - #[allow(unused)] - impl $crate::protocol::FieldAccess<$crate::protocol::meta::Array<$len, $ty>> { - pub const fn meta() -> &'static dyn $crate::protocol::Meta { - &$crate::protocol::meta::Array::<$len, $ty> { _phantom: std::marker::PhantomData } - } - #[inline] - pub const fn size_of_field_at(mut buf: &[u8]) -> Result { - let mut size = std::mem::size_of::<$len>(); - let mut len = match $crate::protocol::FieldAccess::<$len>::extract(buf) { - Ok(n) => n, - Err(e) => return Err(e), - }; - #[allow(unused_comparisons)] - if len < 0 { - return Err($crate::protocol::ParseError::InvalidData); - } - buf = buf.split_at(size).1; - loop { - if len <= 0 { - break; - } - len -= 1; - let elem_size = match $crate::protocol::FieldAccess::<$ty>::size_of_field_at(buf) { - Ok(n) => n, - Err(e) => return Err(e), - }; - buf = buf.split_at(elem_size).1; - size += elem_size; - } - Ok(size) - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result<$crate::protocol::Array<'_, $len, $ty>, $crate::protocol::ParseError> { - match $crate::protocol::FieldAccess::<$len>::extract(buf) { - Ok(len) => Ok($crate::protocol::Array::new(buf.split_at(std::mem::size_of::<$len>()).1, len as u32)), - Err(e) => Err(e) - } - } - #[inline] - pub const fn measure<'a>(buffer: &'a[<$ty as $crate::protocol::Enliven>::ForMeasure<'a>]) -> usize { - let mut size = std::mem::size_of::<$len>(); - let mut index = 0; - loop { - if index + 1 > buffer.len() { - break; - } - let item = &buffer[index]; - size += $crate::protocol::FieldAccess::<$ty>::measure(item); - index += 1; - } - size - } - #[inline(always)] - pub fn copy_to_buf<'a>(buf: &mut $crate::protocol::writer::BufWriter, value: &'a[<$ty as $crate::protocol::Enliven>::ForBuilder<'a>]) { - buf.write(&<$len>::to_be_bytes(value.len() as _)); - for elem in value { - $crate::protocol::FieldAccess::<$ty>::copy_to_buf_ref(buf, 
elem); - } - } - - } - )* - - #[allow(unused)] - impl $crate::protocol::FieldAccess<$crate::protocol::meta::ZTArray<$ty>> { - pub const fn meta() -> &'static dyn $crate::protocol::Meta { - &$crate::protocol::meta::ZTArray::<$ty> { _phantom: std::marker::PhantomData } - } - #[inline] - pub const fn size_of_field_at(mut buf: &[u8]) -> Result { - let mut size = 1; - loop { - if buf.is_empty() { - return Err($crate::protocol::ParseError::TooShort); - } - if buf[0] == 0 { - return Ok(size); - } - let elem_size = match $crate::protocol::FieldAccess::<$ty>::size_of_field_at(buf) { - Ok(n) => n, - Err(e) => return Err(e), - }; - buf = buf.split_at(elem_size).1; - size += elem_size; - } - } - #[inline(always)] - pub const fn extract(mut buf: &[u8]) -> Result<$crate::protocol::ZTArray<$ty>, $crate::protocol::ParseError> { - Ok($crate::protocol::ZTArray::new(buf)) - } - #[inline] - pub const fn measure<'a>(buffer: &'a[<$ty as $crate::protocol::Enliven>::ForMeasure<'a>]) -> usize { - let mut size = 1; - let mut index = 0; - loop { - if index + 1 > buffer.len() { - break; - } - let item = &buffer[index]; - size += $crate::protocol::FieldAccess::<$ty>::measure(item); - index += 1; - } - size - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut $crate::protocol::writer::BufWriter, value: &[<$ty as $crate::protocol::Enliven>::ForBuilder<'_>]) { - for elem in value { - $crate::protocol::FieldAccess::<$ty>::copy_to_buf_ref(buf, elem); - } - buf.write_u8(0); - } - } - }; -} -pub(crate) use array_access; - // Arrays of type [`u8`] are special-cased to return a slice of bytes. 
impl AsRef<[u8]> for Array<'_, T, u8> { fn as_ref(&self) -> &[u8] { @@ -350,3 +260,82 @@ impl<'a, L: TryInto, T: FixedSize + FieldAccessArray> Array<'a, L, T> { } } } + +impl FieldAccessArray + for ArrayMeta +where + for<'a> L::ForBuilder<'a>: TryFrom, + for<'a> L::WithLifetime<'a>: TryInto, +{ + const META: &'static dyn Meta = &ArrayMeta:: { + _phantom: PhantomData, + }; + fn size_of_field_at(mut buf: &[u8]) -> Result { + let mut size = std::mem::size_of::(); + let len = match L::extract(buf) { + Ok(n) => n.try_into(), + Err(e) => return Err(e), + }; + #[allow(unused_comparisons)] + let Ok(mut len) = len + else { + return Err(ParseError::InvalidData); + }; + buf = buf.split_at(size).1; + loop { + if len == 0 { + break; + } + len -= 1; + let elem_size = match T::size_of_field_at(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + buf = buf.split_at(elem_size).1; + size += elem_size; + } + Ok(size) + } + fn extract(buf: &[u8]) -> Result, ParseError> { + let len = match L::extract(buf) { + Ok(len) => len.try_into(), + Err(e) => { + return Err(e); + } + }; + let Ok(len) = len else { + return Err(ParseError::InvalidData); + }; + Ok(Array::new( + buf.split_at(std::mem::size_of::()).1, + len as _, + )) + } + fn copy_to_buf(buf: &mut crate::BufWriter, value: &&[::ForBuilder<'_>]) { + let Ok(len) = L::ForBuilder::try_from(value.len()) else { + panic!("Array length out of bounds"); + }; + L::copy_to_buf(buf, &len); + for elem in *value { + T::copy_to_buf(buf, elem); + } + } +} + +pub struct FixedArrayMeta { + pub _phantom: PhantomData, +} + +impl Meta for FixedArrayMeta { + fn name(&self) -> &'static str { + "FixedArray" + } + + fn fixed_length(&self) -> Option { + Some(S) + } + + fn relations(&self) -> &'static [(MetaRelation, &'static dyn Meta)] { + &[(MetaRelation::Item, ::META)] + } +} diff --git a/rust/pgrust/src/protocol/buffer.rs b/rust/db_proto/src/buffer.rs similarity index 98% rename from rust/pgrust/src/protocol/buffer.rs rename to rust/db_proto/src/buffer.rs 
index 6837407b41a..462978388c2 100644 --- a/rust/pgrust/src/protocol/buffer.rs +++ b/rust/db_proto/src/buffer.rs @@ -154,9 +154,10 @@ impl StructBuffer { #[cfg(test)] mod tests { + use crate::{Encoded, ParseError}; + use super::StructBuffer; - use crate::protocol::postgres::{builder, data::*, meta}; - use crate::protocol::*; + use crate::test_protocol::{builder, data::*, meta}; /// Create a test data buffer containing three messages fn test_data() -> (Vec, Vec) { diff --git a/rust/db_proto/src/datatypes.rs b/rust/db_proto/src/datatypes.rs new file mode 100644 index 00000000000..a2f289fdd4c --- /dev/null +++ b/rust/db_proto/src/datatypes.rs @@ -0,0 +1,558 @@ +use std::{marker::PhantomData, str::Utf8Error}; + +pub use uuid::Uuid; + +use crate::{ + declare_field_access, declare_field_access_fixed_size, writer::BufWriter, Enliven, FieldAccess, + FieldAccessArray, Meta, ParseError, +}; + +pub mod meta { + pub use super::BasicMeta as Basic; + pub use super::EncodedMeta as Encoded; + pub use super::LStringMeta as LString; + pub use super::LengthMeta as Length; + pub use super::RestMeta as Rest; + pub use super::UuidMeta as Uuid; + pub use super::ZTStringMeta as ZTString; +} + +/// Represents the remainder of data in a message. +#[derive(Debug, PartialEq, Eq)] +pub struct Rest<'a> { + buf: &'a [u8], +} + +declare_field_access! 
{ + Meta = RestMeta, + Inflated = Rest<'a>, + Measure = &'a [u8], + Builder = &'a [u8], + + pub const fn meta() -> &'static dyn Meta { + &RestMeta {} + } + + pub const fn size_of_field_at(buf: &[u8]) -> Result { + Ok(buf.len()) + } + + pub const fn extract(buf: &[u8]) -> Result, ParseError> { + Ok(Rest { buf }) + } + + pub const fn measure(buf: &[u8]) -> usize { + buf.len() + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &[u8]) { + buf.write(value) + } + + pub const fn constant(_constant: usize) -> Rest<'static> { + panic!("Constants unsupported for this data type") + } +} + +pub struct RestMeta {} +impl Meta for RestMeta { + fn name(&self) -> &'static str { + "Rest" + } +} + +impl<'a> Rest<'a> {} + +impl<'a> AsRef<[u8]> for Rest<'a> { + fn as_ref(&self) -> &[u8] { + self.buf + } +} + +impl<'a> std::ops::Deref for Rest<'a> { + type Target = [u8]; + fn deref(&self) -> &Self::Target { + self.buf + } +} + +impl PartialEq<[u8]> for Rest<'_> { + fn eq(&self, other: &[u8]) -> bool { + self.buf == other + } +} + +impl PartialEq<&[u8; N]> for Rest<'_> { + fn eq(&self, other: &&[u8; N]) -> bool { + self.buf == *other + } +} + +impl PartialEq<&[u8]> for Rest<'_> { + fn eq(&self, other: &&[u8]) -> bool { + self.buf == *other + } +} + +/// A zero-terminated string. 
+#[allow(unused)] +pub struct ZTString<'a> { + buf: &'a [u8], +} + +declare_field_access!( + Meta = ZTStringMeta, + Inflated = ZTString<'a>, + Measure = &'a str, + Builder = &'a str, + + pub const fn meta() -> &'static dyn Meta { + &ZTStringMeta {} + } + + pub const fn size_of_field_at(buf: &[u8]) -> Result { + let mut i = 0; + loop { + if i >= buf.len() { + return Err(ParseError::TooShort); + } + if buf[i] == 0 { + return Ok(i + 1); + } + i += 1; + } + } + + pub const fn extract(buf: &[u8]) -> Result, ParseError> { + let buf = buf.split_at(buf.len() - 1).0; + Ok(ZTString { buf }) + } + + pub const fn measure(buf: &str) -> usize { + buf.len() + 1 + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &str) { + buf.write(value.as_bytes()); + buf.write_u8(0); + } + + pub const fn constant(_constant: usize) -> ZTString<'static> { + panic!("Constants unsupported for this data type") + } +); + +pub struct ZTStringMeta {} +impl Meta for ZTStringMeta { + fn name(&self) -> &'static str { + "ZTString" + } +} + +impl std::fmt::Debug for ZTString<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + String::from_utf8_lossy(self.buf).fmt(f) + } +} + +impl<'a> ZTString<'a> { + pub fn to_owned(&self) -> Result { + std::str::from_utf8(self.buf).map(|s| s.to_owned()) + } + + pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> { + std::str::from_utf8(self.buf) + } + + pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + String::from_utf8_lossy(self.buf) + } + + pub fn to_bytes(&self) -> &[u8] { + self.buf + } +} + +impl PartialEq for ZTString<'_> { + fn eq(&self, other: &Self) -> bool { + self.buf == other.buf + } +} +impl Eq for ZTString<'_> {} + +impl PartialEq for ZTString<'_> { + fn eq(&self, other: &str) -> bool { + self.buf == other.as_bytes() + } +} + +impl PartialEq<&str> for ZTString<'_> { + fn eq(&self, other: &&str) -> bool { + self.buf == other.as_bytes() + } +} + +impl<'a> TryInto<&'a str> for ZTString<'a> { + type Error = 
Utf8Error; + fn try_into(self) -> Result<&'a str, Self::Error> { + std::str::from_utf8(self.buf) + } +} + +/// A length-prefixed string. +#[allow(unused)] +pub struct LString<'a> { + buf: &'a [u8], +} + +declare_field_access!( + Meta = LStringMeta, + Inflated = LString<'a>, + Measure = &'a str, + Builder = &'a str, + + pub const fn meta() -> &'static dyn Meta { + &LStringMeta {} + } + + pub const fn size_of_field_at(buf: &[u8]) -> Result { + if buf.len() < 4 { + return Err(ParseError::TooShort); + } + let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + Ok(4 + len) + } + + pub const fn extract(buf: &[u8]) -> Result, ParseError> { + if buf.len() < 4 { + return Err(ParseError::TooShort); + } + let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + if buf.len() < 4 + len { + return Err(ParseError::TooShort); + } + Ok(LString { + buf: buf.split_at(4).1, + }) + } + + pub const fn measure(buf: &str) -> usize { + 4 + buf.len() + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &str) { + let len = value.len() as u32; + buf.write(&len.to_be_bytes()); + buf.write(value.as_bytes()); + } + + pub const fn constant(_constant: usize) -> LString<'static> { + panic!("Constants unsupported for this data type") + } +); + +pub struct LStringMeta {} +impl Meta for LStringMeta { + fn name(&self) -> &'static str { + "LString" + } +} + +impl std::fmt::Debug for LString<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + String::from_utf8_lossy(self.buf).fmt(f) + } +} + +impl<'a> LString<'a> { + pub fn to_owned(&self) -> Result { + std::str::from_utf8(self.buf).map(|s| s.to_owned()) + } + + pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> { + std::str::from_utf8(self.buf) + } + + pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + String::from_utf8_lossy(self.buf) + } + + pub fn to_bytes(&self) -> &[u8] { + self.buf + } +} + +impl PartialEq for LString<'_> { + fn eq(&self, other: &Self) -> bool { + 
self.buf == other.buf + } +} +impl Eq for LString<'_> {} + +impl PartialEq for LString<'_> { + fn eq(&self, other: &str) -> bool { + self.buf == other.as_bytes() + } +} + +impl PartialEq<&str> for LString<'_> { + fn eq(&self, other: &&str) -> bool { + self.buf == other.as_bytes() + } +} + +impl<'a> TryInto<&'a str> for LString<'a> { + type Error = Utf8Error; + fn try_into(self) -> Result<&'a str, Self::Error> { + std::str::from_utf8(self.buf) + } +} + +declare_field_access_fixed_size! { + Meta = UuidMeta, + Inflated = Uuid, + Measure = Uuid, + Builder = Uuid, + Size = 16, + Zero = Uuid::nil(), + + pub const fn meta() -> &'static dyn Meta { + &UuidMeta {} + } + + pub const fn extract(buf: &[u8; 16]) -> Result { + Ok(Uuid::from_u128(::from_be_bytes(*buf))) + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &Uuid) { + buf.write(value.as_bytes().as_slice()) + } + + pub const fn constant(_constant: usize) -> Uuid { + panic!("Constants unsupported for this data type") + } +} + +pub struct UuidMeta {} +impl Meta for UuidMeta { + fn name(&self) -> &'static str { + "Uuid" + } +} + +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +/// An encoded row value. +pub enum Encoded<'a> { + #[default] + Null, + Value(&'a [u8]), +} + +impl<'a> Encoded<'a> { + pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + match self { + Encoded::Null => "".into(), + Encoded::Value(value) => String::from_utf8_lossy(value), + } + } +} + +impl<'a> AsRef> for Encoded<'a> { + fn as_ref(&self) -> &Encoded<'a> { + self + } +} + +declare_field_access! 
{ + Meta = EncodedMeta, + Inflated = Encoded<'a>, + Measure = Encoded<'a>, + Builder = Encoded<'a>, + + pub const fn meta() -> &'static dyn Meta { + &EncodedMeta {} + } + + pub const fn size_of_field_at(buf: &[u8]) -> Result { + const N: usize = std::mem::size_of::(); + if let Some(len) = buf.first_chunk::() { + let len = i32::from_be_bytes(*len); + if len == -1 { + Ok(N) + } else if len < 0 { + Err(ParseError::InvalidData) + } else if buf.len() < len as usize + N { + Err(ParseError::TooShort) + } else { + Ok(len as usize + N) + } + } else { + Err(ParseError::TooShort) + } + } + + pub const fn extract(buf: &[u8]) -> Result, ParseError> { + const N: usize = std::mem::size_of::(); + if let Some((len, array)) = buf.split_first_chunk::() { + let len = i32::from_be_bytes(*len); + if len == -1 && array.is_empty() { + Ok(Encoded::Null) + } else if len < 0 { + Err(ParseError::InvalidData) + } else if array.len() < len as _ { + Err(ParseError::TooShort) + } else { + Ok(Encoded::Value(array)) + } + } else { + Err(ParseError::TooShort) + } + } + + pub const fn measure(value: &Encoded) -> usize { + match value { + Encoded::Null => std::mem::size_of::(), + Encoded::Value(value) => value.len() + std::mem::size_of::(), + } + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &Encoded) { + match value { + Encoded::Null => buf.write(&[0xff, 0xff, 0xff, 0xff]), + Encoded::Value(value) => { + let len: i32 = value.len() as _; + buf.write(&len.to_be_bytes()); + buf.write(value); + } + } + } + + pub const fn constant(_constant: usize) -> Encoded<'static> { + panic!("Constants unsupported for this data type") + } +} + +pub struct EncodedMeta {} +impl Meta for EncodedMeta { + fn name(&self) -> &'static str { + "Encoded" + } +} + +impl<'a> Encoded<'a> {} + +impl PartialEq for Encoded<'_> { + fn eq(&self, other: &str) -> bool { + self == &Encoded::Value(other.as_bytes()) + } +} + +impl PartialEq<&str> for Encoded<'_> { + fn eq(&self, other: &&str) -> bool { + self == 
&Encoded::Value(other.as_bytes()) + } +} + +impl PartialEq<[u8]> for Encoded<'_> { + fn eq(&self, other: &[u8]) -> bool { + self == &Encoded::Value(other) + } +} + +impl PartialEq<&[u8]> for Encoded<'_> { + fn eq(&self, other: &&[u8]) -> bool { + self == &Encoded::Value(other) + } +} + +pub struct Length(pub i32); + +declare_field_access_fixed_size! { + Meta = LengthMeta, + Inflated = usize, + Measure = i32, + Builder = i32, + Size = 4, + Zero = 0, + + pub const fn meta() -> &'static dyn Meta { + &LengthMeta {} + } + + pub const fn extract(buf: &[u8; 4]) -> Result { + let n = i32::from_be_bytes(*buf); + if n >= 0 { + Ok(n as _) + } else { + Err(ParseError::InvalidData) + } + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &i32) { + FieldAccess::::copy_to_buf(buf, value) + } + + pub const fn constant(value: usize) -> usize { + value + } +} + +impl FieldAccess { + pub fn copy_to_buf_rewind(buf: &mut BufWriter, rewind: usize, value: usize) { + buf.write_rewind(rewind, &(value as i32).to_be_bytes()); + } +} + +// We alias usize here. Note that if this causes trouble in the future we can +// probably work around this by adding a new "const value" function to +// FieldAccess. For now it works! +pub struct LengthMeta {} + +impl Meta for LengthMeta { + fn name(&self) -> &'static str { + "len" + } +} + +pub struct BasicMeta { + _phantom: PhantomData, +} + +impl Meta for BasicMeta { + fn name(&self) -> &'static str { + std::any::type_name::() + } +} + +macro_rules! basic_types { + ($($ty:ty),*) => { + $( + declare_field_access_fixed_size! 
{ + Meta = $ty, + Inflated = $ty, + Measure = $ty, + Builder = $ty, + Size = std::mem::size_of::<$ty>(), + Zero = 0, + + pub const fn meta() -> &'static dyn Meta { + &BasicMeta::<$ty> { _phantom: PhantomData } + } + + pub const fn extract(buf: &[u8; std::mem::size_of::<$ty>()]) -> Result<$ty, ParseError> { + Ok(<$ty>::from_be_bytes(*buf)) + } + + pub fn copy_to_buf(buf: &mut BufWriter, value: &$ty) { + buf.write(&<$ty>::to_be_bytes(*value)); + } + + pub const fn constant(value: usize) -> $ty { + value as _ + } + } + )* + }; +} + +basic_types!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128); diff --git a/rust/db_proto/src/field_access.rs b/rust/db_proto/src/field_access.rs new file mode 100644 index 00000000000..fc4f96d950e --- /dev/null +++ b/rust/db_proto/src/field_access.rs @@ -0,0 +1,569 @@ +use crate::{BufWriter, Enliven, Meta, ParseError}; + +/// As Rust does not currently support const in traits, we use this struct to +/// provide the const methods. It requires more awkward code, so we make use of +/// macros to generate the code. +/// +/// Note that another consequence is that we have to declare this struct twice: +/// once for this crate, and again when someone tries to instantiate a protocol. +/// The reason for this is that we cannot add additional `impl`s for this `FieldAccess` +/// outside of this crate. Instead, we use a macro to "copy" the existing `impl`s from +/// this crate to the newtype. +pub struct FieldAccess { + _phantom_data: std::marker::PhantomData, +} + +/// Delegates to a concrete [`FieldAccess`] but as a non-const trait. This is +/// used for performing extraction in iterators. 
+pub trait FieldAccessArray: Enliven { + const META: &'static dyn Meta; + fn size_of_field_at(buf: &[u8]) -> Result; + fn extract(buf: &[u8]) -> Result<::WithLifetime<'_>, ParseError>; + fn copy_to_buf(buf: &mut BufWriter, value: &Self::ForBuilder<'_>); +} + +/// A trait for types which are fixed-size, used to provide a `get` implementation +/// in arrays and iterators. +pub trait FixedSize: Enliven { + const SIZE: usize; + /// Extract this type from the given buffer, assuming that enough bytes are available. + fn extract_infallible(buf: &[u8]) -> ::WithLifetime<'_>; +} + +/// Declares a field access for a given type which is variably-sized. +#[macro_export] +#[doc(hidden)] +macro_rules! declare_field_access { + ( + Meta = $meta:ty, + Inflated = $inflated:ty, + Measure = $measured:ty, + Builder = $builder:ty, + + pub const fn meta() -> &'static dyn Meta + $meta_body:block + + pub const fn size_of_field_at($size_of_arg0:ident : &[u8]) -> Result + $size_of:block + + pub const fn extract($extract_arg0:ident : &[u8]) -> Result<$extract_ret:ty, ParseError> + $extract:block + + pub const fn measure($measure_arg0:ident : &$measure_param:ty) -> usize + $measure:block + + pub fn copy_to_buf($copy_arg0:ident : &mut BufWriter, $copy_arg1:ident : &$value_param:ty) + $copy:block + + pub const fn constant($constant_arg0:ident : usize) -> $constant_ret:ty + $constant:block + ) => { + impl Enliven for $meta { + type WithLifetime<'a> = $inflated; + type ForMeasure<'a> = $measured; + type ForBuilder<'a> = $builder; + } + + impl FieldAccess<$meta> { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + $meta_body + } + #[inline(always)] + pub const fn size_of_field_at($size_of_arg0: &[u8]) -> Result { + $size_of + } + #[inline(always)] + pub const fn extract($extract_arg0: &[u8]) -> Result<$extract_ret, ParseError> { + $extract + } + #[inline(always)] + pub const fn measure($measure_arg0: &$measure_param) -> usize { + $measure + } + #[inline(always)] + pub fn 
copy_to_buf($copy_arg0: &mut BufWriter, $copy_arg1: &$value_param) { + $copy + } + #[inline(always)] + pub const fn constant($constant_arg0: usize) -> $constant_ret { + $constant + } + } + + $crate::field_access!($crate::FieldAccess, $meta); + $crate::array_access!(variable, $crate::FieldAccess, $meta); + }; +} + +/// Declares a field access for a given type which is fixed-size. Fixed-size +/// fields have simpler extraction logic, and support mapping to Rust arrays. +#[macro_export] +#[doc(hidden)] +macro_rules! declare_field_access_fixed_size { + ( + Meta = $meta:ty, + Inflated = $inflated:ty, + Measure = $measured:ty, + Builder = $builder:ty, + Size = $size:expr, + Zero = $zero:expr, + + pub const fn meta() -> &'static dyn Meta + $meta_body:block + + pub const fn extract($extract_arg0:ident : &$extract_type:ty) -> Result<$extract_ret:ty, ParseError> + $extract:block + + pub fn copy_to_buf($copy_arg0:ident : &mut BufWriter, $copy_arg1:ident : &$value_param:ty) + $copy:block + + pub const fn constant($constant_arg0:ident : usize) -> $constant_ret:ty + $constant:block + ) => { + impl Enliven for $meta { + type WithLifetime<'a> = $inflated; + type ForMeasure<'a> = $measured; + type ForBuilder<'a> = $builder; + } + + impl FieldAccess<$meta> { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + $meta_body + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + if let Ok(_) = Self::extract(buf) { + Ok($size) + } else { + Err(ParseError::TooShort) + } + } + #[inline(always)] + pub const fn extract($extract_arg0: &[u8]) -> Result<$extract_ret, ParseError> { + if let Some(chunk) = $extract_arg0.first_chunk() { + FieldAccess::<$meta>::extract_exact(chunk) + } else { + Err(ParseError::TooShort) + } + } + #[inline(always)] + pub const fn extract_exact( + $extract_arg0: &[u8; $size], + ) -> Result<$extract_ret, ParseError> { + $extract + } + #[inline(always)] + pub const fn measure(_: &$measured) -> usize { + $size + } + 
#[inline(always)] + pub fn copy_to_buf($copy_arg0: &mut BufWriter, $copy_arg1: &$value_param) { + $copy + } + #[inline(always)] + pub const fn constant($constant_arg0: usize) -> $constant_ret { + $constant + } + } + + impl $crate::FixedSize for $meta { + const SIZE: usize = std::mem::size_of::<$inflated>(); + #[inline(always)] + fn extract_infallible(buf: &[u8]) -> $inflated { + FieldAccess::<$meta>::extract(buf).unwrap() + } + } + + impl Enliven for $crate::meta::FixedArray { + type WithLifetime<'a> = [$inflated; S]; + type ForMeasure<'a> = [$measured; S]; + type ForBuilder<'a> = [$builder; S]; + } + + #[allow(unused)] + impl FieldAccess<$crate::meta::FixedArray> { + #[inline(always)] + pub const fn meta() -> &'static dyn Meta { + &$crate::meta::FixedArray:: { + _phantom: PhantomData, + } + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + let size = $size * S; + if size > buf.len() { + Err($crate::ParseError::TooShort) + } else { + Ok(size) + } + } + #[inline(always)] + pub const fn measure(_: &[$measured; S]) -> usize { + ($size * (S)) + } + #[inline(always)] + pub const fn extract(mut buf: &[u8]) -> Result<[$inflated; S], $crate::ParseError> { + let mut out: [$inflated; S] = [const { $zero }; S]; + let mut i = 0; + loop { + if i == S { + break; + } + (out[i], buf) = if let Some((bytes, rest)) = buf.split_first_chunk() { + match FieldAccess::<$meta>::extract_exact(bytes) { + Ok(value) => (value, rest), + Err(e) => return Err(e), + } + } else { + return Err($crate::ParseError::TooShort); + }; + i += 1; + } + Ok(out) + } + #[inline(always)] + pub fn copy_to_buf(mut buf: &mut BufWriter, value: &[$builder; S]) { + if !buf.test(std::mem::size_of::<$builder>() * S) { + return; + } + for n in value { + FieldAccess::<$meta>::copy_to_buf(buf, n); + } + } + } + + impl FieldAccessArray for $crate::meta::FixedArray { + const META: &'static dyn Meta = FieldAccess::<$meta>::meta(); + #[inline(always)] + fn size_of_field_at(buf: &[u8]) -> Result { 
+ // TODO: needs to verify the values as well + FieldAccess::<$meta>::size_of_field_at(buf).map(|size| size * S) + } + #[inline(always)] + fn extract(mut buf: &[u8]) -> Result<[$inflated; S], ParseError> { + let mut out = [$zero; S]; + for i in 0..S { + (out[i], buf) = if let Some((bytes, rest)) = buf.split_first_chunk() { + (FieldAccess::<$meta>::extract_exact(bytes)?, rest) + } else { + return Err(ParseError::TooShort); + }; + } + Ok(out) + } + #[inline(always)] + fn copy_to_buf(buf: &mut BufWriter, value: &[$builder; S]) { + for n in value { + FieldAccess::<$meta>::copy_to_buf(buf, n); + } + } + } + + $crate::field_access!($crate::FieldAccess, $meta); + $crate::array_access!(fixed, $crate::FieldAccess, $meta); + }; +} + +/// Delegate to the concrete [`FieldAccess`] for each type we want to extract. +#[macro_export] +#[doc(hidden)] +macro_rules! field_access { + ($acc:ident :: FieldAccess, $ty:ty) => { + impl $crate::FieldAccessArray for $ty { + const META: &'static dyn $crate::Meta = $acc::FieldAccess::<$ty>::meta(); + #[inline(always)] + fn size_of_field_at(buf: &[u8]) -> Result { + $acc::FieldAccess::<$ty>::size_of_field_at(buf) + } + #[inline(always)] + fn extract( + buf: &[u8], + ) -> Result<::WithLifetime<'_>, $crate::ParseError> { + $acc::FieldAccess::<$ty>::extract(buf) + } + #[inline(always)] + fn copy_to_buf( + buf: &mut $crate::BufWriter, + value: &<$ty as $crate::Enliven>::ForBuilder<'_>, + ) { + $acc::FieldAccess::<$ty>::copy_to_buf(buf, value) + } + } + }; +} + +/// Define array accesses for inflated, strongly-typed arrays of both +/// zero-terminated and length-delimited types. +#[macro_export] +#[doc(hidden)] +macro_rules! 
array_access { + (fixed, $acc:ident :: FieldAccess, $ty:ty) => { + $crate::array_access!(fixed, $acc :: FieldAccess, $ty | u8 i16 u16 i32 u32); + }; + (variable, $acc:ident :: FieldAccess, $ty:ty) => { + $crate::array_access!(variable, $acc :: FieldAccess, $ty | u8 i16 u16 i32 u32); + }; + (fixed, $acc:ident :: FieldAccess, $ty:ty | $($len:ty)*) => { + $( + #[allow(unused)] + impl FieldAccess<$crate::meta::Array<$len, $ty>> { + pub const fn meta() -> &'static dyn Meta { + &$crate::meta::Array::<$len, $ty> { _phantom: PhantomData } + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + const N: usize = <$ty as $crate::FixedSize>::SIZE; + const L: usize = std::mem::size_of::<$len>(); + if let Some(len) = buf.first_chunk::() { + let len_value = <$len>::from_be_bytes(*len); + #[allow(unused_comparisons)] + if len_value < 0 { + return Err($crate::ParseError::InvalidData); + } + let mut byte_len = len_value as usize; + byte_len = match byte_len.checked_mul(N) { + Some(l) => l, + None => return Err($crate::ParseError::TooShort), + }; + byte_len = match byte_len.checked_add(L) { + Some(l) => l, + None => return Err($crate::ParseError::TooShort), + }; + if buf.len() < byte_len { + Err($crate::ParseError::TooShort) + } else { + Ok(byte_len) + } + } else { + Err($crate::ParseError::TooShort) + } + } + #[inline(always)] + pub const fn extract(mut buf: &[u8]) -> Result<$crate::Array<$len, $ty>, $crate::ParseError> { + const N: usize = <$ty as $crate::FixedSize>::SIZE; + const L: usize = std::mem::size_of::<$len>(); + if let Some((len, array)) = buf.split_first_chunk::() { + let len_value = <$len>::from_be_bytes(*len); + #[allow(unused_comparisons)] + if len_value < 0 { + return Err($crate::ParseError::InvalidData); + } + let mut byte_len = len_value as usize; + byte_len = match byte_len.checked_mul(N) { + Some(l) => l, + None => return Err($crate::ParseError::TooShort), + }; + byte_len = match byte_len.checked_add(L) { + Some(l) => l, + None => return 
Err($crate::ParseError::TooShort), + }; + if buf.len() < byte_len { + Err($crate::ParseError::TooShort) + } else { + Ok($crate::Array::new(array, len_value as u32)) + } + } else { + Err($crate::ParseError::TooShort) + } + } + #[inline(always)] + pub const fn measure<'a>(buffer: &'a[<$ty as $crate::Enliven>::ForMeasure<'a>]) -> usize { + buffer.len() * std::mem::size_of::<$ty>() + std::mem::size_of::<$len>() + } + #[inline(always)] + pub fn copy_to_buf<'a>(mut buf: &mut BufWriter, value: &'a[<$ty as $crate::Enliven>::ForBuilder<'a>]) { + let size: usize = std::mem::size_of::<$ty>() * value.len() + std::mem::size_of::<$len>(); + if !buf.test(size) { + return; + } + buf.write(&<$len>::to_be_bytes(value.len() as _)); + for n in value { + $acc::FieldAccess::<$ty>::copy_to_buf(buf, n); + } + } + #[inline(always)] + pub const fn constant(value: usize) -> $crate::Array<'static, $len, $ty> { + panic!("Constants unsupported for this data type") + } + } + )* + + #[allow(unused)] + impl $acc::FieldAccess<$crate::meta::ZTArray<$ty>> { + pub const fn meta() -> &'static dyn $crate::Meta { + &$crate::meta::ZTArray::<$ty> { _phantom: std::marker::PhantomData } + } + #[inline] + pub const fn size_of_field_at(mut buf: &[u8]) -> Result { + let mut size = 1; + loop { + if buf.is_empty() { + return Err($crate::ParseError::TooShort); + } + if buf[0] == 0 { + return Ok(size); + } + let elem_size = match $acc::FieldAccess::<$ty>::size_of_field_at(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + buf = buf.split_at(elem_size).1; + size += elem_size; + } + } + #[inline(always)] + pub const fn extract(mut buf: &[u8]) -> Result<$crate::ZTArray<$ty>, $crate::ParseError> { + Ok($crate::ZTArray::new(buf)) + } + #[inline] + pub const fn measure<'a>(buffer: &'a[<$ty as $crate::Enliven>::ForMeasure<'a>]) -> usize { + let mut size = 1; + let mut index = 0; + loop { + if index + 1 > buffer.len() { + break; + } + let item = &buffer[index]; + size += $acc::FieldAccess::<$ty>::measure(item); + index 
+= 1; + } + size + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut $crate::BufWriter, value: &[<$ty as $crate::Enliven>::ForBuilder<'_>]) { + for elem in value { + $acc::FieldAccess::<$ty>::copy_to_buf(buf, elem); + } + buf.write_u8(0); + } + #[inline(always)] + pub const fn constant(value: usize) -> $crate::ZTArray<'static, $ty> { + panic!("Constants unsupported for this data type") + } + } + }; + (variable, $acc:ident :: FieldAccess, $ty:ty | $($len:ty)*) => { + $( + #[allow(unused)] + impl $acc::FieldAccess<$crate::meta::Array<$len, $ty>> { + pub const fn meta() -> &'static dyn $crate::Meta { + &$crate::meta::Array::<$len, $ty> { _phantom: std::marker::PhantomData } + } + #[inline] + pub const fn size_of_field_at(mut buf: &[u8]) -> Result { + let mut size = std::mem::size_of::<$len>(); + let mut len = match $acc::FieldAccess::<$len>::extract(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + #[allow(unused_comparisons)] + if len < 0 { + return Err($crate::ParseError::InvalidData); + } + buf = buf.split_at(size).1; + loop { + if len <= 0 { + break; + } + len -= 1; + let elem_size = match $acc::FieldAccess::<$ty>::size_of_field_at(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + buf = buf.split_at(elem_size).1; + size += elem_size; + } + Ok(size) + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> Result<$crate::Array<'_, $len, $ty>, $crate::ParseError> { + match $acc::FieldAccess::<$len>::extract(buf) { + Ok(len) => Ok($crate::Array::new(buf.split_at(std::mem::size_of::<$len>()).1, len as u32)), + Err(e) => Err(e) + } + } + #[inline] + pub const fn measure<'a>(buffer: &'a[<$ty as $crate::Enliven>::ForMeasure<'a>]) -> usize { + let mut size = std::mem::size_of::<$len>(); + let mut index = 0; + loop { + if index + 1 > buffer.len() { + break; + } + let item = &buffer[index]; + size += $acc::FieldAccess::<$ty>::measure(item); + index += 1; + } + size + } + #[inline(always)] + pub fn copy_to_buf<'a>(buf: &mut $crate::BufWriter, value: &'a[<$ty 
as $crate::Enliven>::ForBuilder<'a>]) { + buf.write(&<$len>::to_be_bytes(value.len() as _)); + for elem in value { + $acc::FieldAccess::<$ty>::copy_to_buf(buf, elem); + } + } + #[inline(always)] + pub const fn constant(value: usize) -> $crate::Array<'static, $len, $ty> { + panic!("Constants unsupported for this data type") + } + } + )* + + #[allow(unused)] + impl $acc::FieldAccess<$crate::meta::ZTArray<$ty>> { + pub const fn meta() -> &'static dyn $crate::Meta { + &$crate::meta::ZTArray::<$ty> { _phantom: std::marker::PhantomData } + } + #[inline] + pub const fn size_of_field_at(mut buf: &[u8]) -> Result { + let mut size = 1; + loop { + if buf.is_empty() { + return Err($crate::ParseError::TooShort); + } + if buf[0] == 0 { + return Ok(size); + } + let elem_size = match $acc::FieldAccess::<$ty>::size_of_field_at(buf) { + Ok(n) => n, + Err(e) => return Err(e), + }; + buf = buf.split_at(elem_size).1; + size += elem_size; + } + } + #[inline(always)] + pub const fn extract(mut buf: &[u8]) -> Result<$crate::ZTArray<$ty>, $crate::ParseError> { + Ok($crate::ZTArray::new(buf)) + } + #[inline] + pub const fn measure<'a>(buffer: &'a[<$ty as $crate::Enliven>::ForMeasure<'a>]) -> usize { + let mut size = 1; + let mut index = 0; + loop { + if index + 1 > buffer.len() { + break; + } + let item = &buffer[index]; + size += $acc::FieldAccess::<$ty>::measure(item); + index += 1; + } + size + } + #[inline(always)] + pub fn copy_to_buf(buf: &mut $crate::BufWriter, value: &[<$ty as $crate::Enliven>::ForBuilder<'_>]) { + for elem in value { + $acc::FieldAccess::<$ty>::copy_to_buf(buf, elem); + } + buf.write_u8(0); + } + #[inline(always)] + pub const fn constant(value: usize) -> $crate::ZTArray<'static, $ty> { + panic!("Constants unsupported for this data type") + } + } + }; +} diff --git a/rust/pgrust/src/protocol/gen.rs b/rust/db_proto/src/gen.rs similarity index 71% rename from rust/pgrust/src/protocol/gen.rs rename to rust/db_proto/src/gen.rs index e0fb066ea65..15e4e0cf6c6 100644 --- 
a/rust/pgrust/src/protocol/gen.rs +++ b/rust/db_proto/src/gen.rs @@ -10,7 +10,7 @@ /// level of the macro adds its own layer of processing and metadata /// accumulation, eventually leading to the final output. /// -/// The `struct_elaborate!` macro is a tool designed to perform an initial +/// The `$crate::struct_elaborate!` macro is a tool designed to perform an initial /// parsing pass on a Rust `struct`, enriching it with metadata to facilitate /// further macro processing. It begins by extracting and analyzing the fields /// of the `struct`, capturing associated metadata such as attributes and types. @@ -35,6 +35,8 @@ /// it reconstructs an enriched `struct`-like data blob using the accumulated /// metadata. It then passes this enriched `struct` to the `next` macro for /// further processing. +#[doc(hidden)] +#[macro_export] macro_rules! struct_elaborate { ( $next:ident $( ($($next_args:tt)*) )? => @@ -50,7 +52,7 @@ macro_rules! struct_elaborate { ) => { // paste! is necessary here because it allows us to re-interpret a "ty" // as an explicit type pattern below. - struct_elaborate!(__builder_type__ + $crate::struct_elaborate!(__builder_type__ // Pass down a "fixed offset" flag that indicates whether the // current field is at a fixed offset. This gets reset to // `no_fixed_offset` when we hit a variable-sized field. @@ -74,66 +76,66 @@ macro_rules! 
struct_elaborate { // End of push-down automation - jumps to `__finalize__` (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields() accum($($faccum:tt)*) original($($original:tt)*)) => { - struct_elaborate!(__finalize__ accum($($faccum)*) original($($original)*)); + $crate::struct_elaborate!(__finalize__ accum($($faccum)*) original($($original)*)); }; // Skip __builder_value__ for 'len' (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(len)(len), value(), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::protocol::meta::Length), size(fixed=fixed), value(auto=auto), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::meta::Length), size(fixed=fixed), value(auto=auto), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(len)(len), value($($value:tt)+), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::protocol::meta::Length), size(fixed=fixed), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+4)) fields([type($crate::meta::Length), size(fixed=fixed), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); }; // Pattern match on known fixed-sized types and mark them as `size(fixed=fixed)` - (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type([u8; $len:literal])($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + (__builder_type__ fixed($fixed:ident 
$fixed_expr:expr) fields([type([$elem:ty; $len:literal])($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($crate::meta::FixedArray<$len, $elem>), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(u8)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr)fields([type(i16)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(i32)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(u32)($ty:ty), $($rest:tt)*] 
$($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(u64)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type(Uuid)($ty:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>$fixed $fixed_expr=>($fixed_expr+std::mem::size_of::<$ty>())) fields([type($ty), size(fixed=fixed), $($rest)*] $($frest)*) $($srest)*); }; // Fallback for other types - variable sized (__builder_type__ fixed($fixed:ident $fixed_expr:expr) fields([type($ty:ty)($ty2:ty), $($rest:tt)*] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_value__ fixed($fixed=>no_fixed_offset $fixed_expr=>(0)) fields([type($ty), size(variable=variable), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_value__ fixed($fixed=>no_fixed_offset $fixed_expr=>(0)) fields([type($ty), size(variable=variable), $($rest)*] $($frest)*) $($srest)*); }; // Next, mark the 
presence or absence of a value (__builder_value__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ type($ty:ty), size($($size:tt)*), value(), $($rest:tt)* ] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value(no_value=no_value), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value(no_value=no_value), $($rest)*] $($frest)*) $($srest)*); }; (__builder_value__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ type($ty:ty), size($($size:tt)*), value($($value:tt)+), $($rest:tt)* ] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder_docs__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value(value=($($value)*)), $($rest)*] $($frest)*) $($srest)*); }; // Next, handle missing docs (__builder_docs__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ type($ty:ty), size($($size:tt)*), value($($value:tt)*), docs(), name($field:ident), $($rest:tt)* ] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value($($value)*), docs(concat!("`", stringify!($field), "` field.")), name($field), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value($($value)*), docs(concat!("`", stringify!($field), "` field.")), name($field), $($rest)*] $($frest)*) $($srest)*); }; 
(__builder_docs__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ type($ty:ty), size($($size:tt)*), value($($value:tt)*), docs($($fdoc:literal)+), $($rest:tt)* ] $($frest:tt)*) $($srest:tt)*) => { - struct_elaborate!(__builder__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value($($value)*), docs(concat!($($fdoc)+)), $($rest)*] $($frest)*) $($srest)*); + $crate::struct_elaborate!(__builder__ fixed($fixed=>$fixed_new $fixed_expr=>$fixed_expr_new) fields([type($ty), size($($size)*), value($($value)*), docs(concat!($($fdoc)+)), $($rest)*] $($frest)*) $($srest)*); }; @@ -141,7 +143,7 @@ macro_rules! struct_elaborate { (__builder__ fixed($fixed:ident=>$fixed_new:ident $fixed_expr:expr=>$fixed_expr_new:expr) fields([ type($ty:ty), size($($size:tt)*), value($($value:tt)*), docs($fdoc:expr), name($field:ident), $($rest:tt)* ] $($frest:tt)*) accum($($faccum:tt)*) original($($original:tt)*)) => { - struct_elaborate!(__builder_type__ fixed($fixed_new $fixed_expr_new) fields($($frest)*) accum( + $crate::struct_elaborate!(__builder_type__ fixed($fixed_new $fixed_expr_new) fields($($frest)*) accum( $($faccum)* { name($field), @@ -169,19 +171,50 @@ macro_rules! struct_elaborate { } } -macro_rules! protocol { +/// Generates a protocol definition from a Rust-like DSL. +/// +/// ``` +/// struct Foo { +/// bar: u8, +/// baz: u16, +/// } +/// ``` +#[doc(hidden)] +#[macro_export] +macro_rules! __protocol { ($( $( #[ $sdoc:meta ] )* struct $name:ident $(: $super:ident)? { $($struct:tt)+ } )+) => { + mod access { + #![allow(unused)] + + /// This struct is specialized for each type we want to extract data from. We + /// have to do it this way to work around Rust's lack of const specialization. 
+ pub struct FieldAccess { + _phantom_data: std::marker::PhantomData, + } + + $crate::field_access_copy!{basic $crate::FieldAccess, self::FieldAccess, + i8, u8, i16, u16, i32, u32, i64, u64, i128, u128, + $crate::meta::Uuid + } + $crate::field_access_copy!{$crate::FieldAccess, self::FieldAccess, + $crate::meta::ZTString, + $crate::meta::LString, + $crate::meta::Rest, + $crate::meta::Encoded, + $crate::meta::Length + } + } + $( - paste::paste!( + $crate::paste!( #[allow(unused_imports)] pub(crate) mod [<__ $name:lower>] { + use $crate::{meta::*, protocol_builder}; use super::meta::*; - use $crate::protocol::meta::*; - use $crate::protocol::gen::*; - struct_elaborate!(protocol_builder(__struct__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); - struct_elaborate!(protocol_builder(__meta__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); - struct_elaborate!(protocol_builder(__measure__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); - struct_elaborate!(protocol_builder(__builder__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + $crate::struct_elaborate!(protocol_builder(__struct__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + $crate::struct_elaborate!(protocol_builder(__meta__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + $crate::struct_elaborate!(protocol_builder(__measure__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); + $crate::struct_elaborate!(protocol_builder(__builder__) => $( #[ $sdoc ] )* struct $name $(: $super)? { $($struct)+ } ); } ); )+ @@ -189,7 +222,7 @@ macro_rules! protocol { pub mod data { #![allow(unused_imports)] $( - paste::paste!( + $crate::paste!( pub use super::[<__ $name:lower>]::$name; ); )+ @@ -197,7 +230,7 @@ macro_rules! protocol { pub mod meta { #![allow(unused_imports)] $( - paste::paste!( + $crate::paste!( pub use super::[<__ $name:lower>]::[<$name Meta>] as $name; ); )+ @@ -205,7 +238,7 @@ macro_rules! 
protocol { /// A slice containing the metadata references for all structs in /// this definition. #[allow(unused)] - pub const ALL: &'static [&'static dyn $crate::protocol::Meta] = &[ + pub const ALL: &'static [&'static dyn $crate::Meta] = &[ $( &$name {} ),* @@ -214,7 +247,7 @@ macro_rules! protocol { pub mod builder { #![allow(unused_imports)] $( - paste::paste!( + $crate::paste!( pub use super::[<__ $name:lower>]::[<$name Builder>] as $name; ); )+ @@ -222,7 +255,7 @@ macro_rules! protocol { pub mod measure { #![allow(unused_imports)] $( - paste::paste!( + $crate::paste!( pub use super::[<__ $name:lower>]::[<$name Measure>] as $name; ); )+ @@ -230,6 +263,11 @@ macro_rules! protocol { }; } +#[doc(inline)] +pub use __protocol as protocol; + +#[macro_export] +#[doc(hidden)] macro_rules! r#if { (__is_empty__ [] {$($true:tt)*} else {$($false:tt)*}) => { $($true)* @@ -244,6 +282,8 @@ macro_rules! r#if { }; } +#[doc(hidden)] +#[macro_export] macro_rules! protocol_builder { (__struct__, struct $name:ident { super($($super:ident)?), @@ -258,7 +298,7 @@ macro_rules! protocol_builder { $($rest:tt)* },)*), }) => { - paste::paste!( + $crate::paste!( /// Our struct we are building. type S<'a> = $name<'a>; /// The meta-struct for the struct we are building. @@ -321,23 +361,23 @@ macro_rules! protocol_builder { $( $( - let Ok(val) = $crate::protocol::FieldAccess::<$type>::extract(buf.split_at(offset).1) else { + let Ok(val) = super::access::FieldAccess::<$type>::extract(buf.split_at(offset).1) else { return false; }; if val as usize != $value as usize { return false; } )? 
- offset += std::mem::size_of::<$type>(); + offset += std::mem::size_of::<<$type as $crate::Enliven>::ForBuilder<'static>>(); )* true } $( - pub const fn can_cast(parent: &<$super as $crate::protocol::Enliven>::WithLifetime<'a>) -> bool { + pub const fn can_cast(parent: &::WithLifetime<'a>) -> bool { Self::is_buffer(parent.__buf) } - pub const fn try_new(parent: &<$super as $crate::protocol::Enliven>::WithLifetime<'a>) -> Option { + pub const fn try_new(parent: &::WithLifetime<'a>) -> Option { if Self::can_cast(parent) { // TODO let Ok(value) = Self::new(parent.__buf) else { @@ -352,13 +392,13 @@ macro_rules! protocol_builder { /// Creates a new instance of this struct from a given buffer. #[inline] - pub const fn new(mut buf: &'a [u8]) -> Result { + pub const fn new(mut buf: &'a [u8]) -> Result { let mut __field_offsets = [0; Meta::FIELD_COUNT + 1]; let mut offset = 0; let mut index = 0; $( __field_offsets[index] = offset; - offset += match $crate::protocol::FieldAccess::<$type>::size_of_field_at(buf.split_at(offset).1) { + offset += match super::access::FieldAccess::<$type>::size_of_field_at(buf.split_at(offset).1) { Ok(n) => n, Err(e) => return Err(e), }; @@ -380,14 +420,14 @@ macro_rules! 
protocol_builder { #[doc = $fdoc] #[allow(unused)] #[inline] - pub const fn $field<'s>(&'s self) -> <$type as $crate::protocol::Enliven>::WithLifetime<'a> where 's : 'a { + pub const fn $field<'s>(&'s self) -> <$type as $crate::Enliven>::WithLifetime<'a> where 's : 'a { // Perform a const buffer extraction operation let offset1 = self.__field_offsets[F::$field as usize]; let offset2 = self.__field_offsets[F::$field as usize + 1]; let (_, buf) = self.__buf.split_at(offset1); let (buf, _) = buf.split_at(offset2 - offset1); // This will not panic: we've confirmed the validity of the buffer when sizing - let Ok(value) = $crate::protocol::FieldAccess::<$type>::extract(buf) else { + let Ok(value) = super::access::FieldAccess::<$type>::extract(buf) else { panic!(); }; value @@ -410,7 +450,7 @@ macro_rules! protocol_builder { $($rest:tt)* },)*), }) => { - paste::paste!( + $crate::paste!( $( #[$sdoc] )? #[allow(unused)] #[derive(Debug, Default)] @@ -430,25 +470,25 @@ macro_rules! protocol_builder { #[allow(unused)] impl Meta { pub const FIELD_COUNT: usize = [$(stringify!($field)),*].len(); - $($(pub const [<$field:upper _VALUE>]: $type = $crate::protocol::FieldAccess::<$type>::constant($value as usize);)?)* + $($(pub const [<$field:upper _VALUE>]: <$type as $crate::Enliven>::WithLifetime<'static> = super::access::FieldAccess::<$type>::constant($value as usize);)?)* } - impl $crate::protocol::Meta for Meta { + impl $crate::Meta for Meta { fn name(&self) -> &'static str { stringify!($name) } - fn relations(&self) -> &'static [($crate::protocol::MetaRelation, &'static dyn $crate::protocol::Meta)] { - r#if!(__is_empty__ [$($super)?] { - const RELATIONS: &'static [($crate::protocol::MetaRelation, &'static dyn $crate::protocol::Meta)] = &[ + fn relations(&self) -> &'static [($crate::MetaRelation, &'static dyn $crate::Meta)] { + $crate::r#if!(__is_empty__ [$($super)?] 
{ + const RELATIONS: &'static [($crate::MetaRelation, &'static dyn $crate::Meta)] = &[ $( - ($crate::protocol::MetaRelation::Field(stringify!($field)), $crate::protocol::FieldAccess::<$type>::meta()) + ($crate::MetaRelation::Field(stringify!($field)), super::access::FieldAccess::<$type>::meta()) ),* ]; } else { - const RELATIONS: &'static [($crate::protocol::MetaRelation, &'static dyn $crate::protocol::Meta)] = &[ - ($crate::protocol::MetaRelation::Parent, $crate::protocol::FieldAccess::<$($super)?>::meta()), + const RELATIONS: &'static [($crate::MetaRelation, &'static dyn $crate::Meta)] = &[ + ($crate::MetaRelation::Parent, super::access::FieldAccess::::meta()), $( - ($crate::protocol::MetaRelation::Field(stringify!($field)), $crate::protocol::FieldAccess::<$type>::meta()) + ($crate::MetaRelation::Field(stringify!($field)), super::access::FieldAccess::<$type>::meta()) ),* ]; }); @@ -460,9 +500,9 @@ macro_rules! protocol_builder { protocol_builder!(__meta__, $fixed($fixed_expr) $field $type); )* - impl $crate::protocol::StructMeta for Meta { + impl $crate::StructMeta for Meta { type Struct<'a> = S<'a>; - fn new(buf: &[u8]) -> Result, $crate::protocol::ParseError> { + fn new(buf: &[u8]) -> Result, $crate::ParseError> { S::new(buf) } fn to_vec(s: &Self::Struct<'_>) -> Vec { @@ -470,27 +510,27 @@ macro_rules! 
protocol_builder { } } - impl $crate::protocol::Enliven for Meta { + impl $crate::Enliven for Meta { type WithLifetime<'a> = S<'a>; type ForMeasure<'a> = M<'a>; type ForBuilder<'a> = B<'a>; } #[allow(unused)] - impl $crate::protocol::FieldAccess { + impl super::access::FieldAccess { #[inline(always)] pub const fn name() -> &'static str { stringify!($name) } #[inline(always)] - pub const fn meta() -> &'static dyn $crate::protocol::Meta { + pub const fn meta() -> &'static dyn $crate::Meta { &Meta {} } #[inline] - pub const fn size_of_field_at(buf: &[u8]) -> Result { + pub const fn size_of_field_at(buf: &[u8]) -> Result { let mut offset = 0; $( - offset += match $crate::protocol::FieldAccess::<$type>::size_of_field_at(buf.split_at(offset).1) { + offset += match super::access::FieldAccess::<$type>::size_of_field_at(buf.split_at(offset).1) { Ok(n) => n, Err(e) => return Err(e), }; @@ -498,7 +538,7 @@ macro_rules! protocol_builder { Ok(offset) } #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result<$name<'_>, $crate::protocol::ParseError> { + pub const fn extract(buf: &[u8]) -> Result<$name<'_>, $crate::ParseError> { $name::new(buf) } #[inline(always)] @@ -506,21 +546,18 @@ macro_rules! 
protocol_builder { measure.measure() } #[inline(always)] - pub fn copy_to_buf(buf: &mut $crate::protocol::writer::BufWriter, builder: &B) { - builder.copy_to_buf(buf) - } - #[inline(always)] - pub fn copy_to_buf_ref(buf: &mut $crate::protocol::writer::BufWriter, builder: &B) { + pub fn copy_to_buf(buf: &mut $crate::BufWriter, builder: &B) { builder.copy_to_buf(buf) } } - $crate::protocol::field_access!{[<$name Meta>]} - $crate::protocol::arrays::array_access!{[<$name Meta>]} + use super::access::FieldAccess as FieldAccess; + $crate::field_access!{self::FieldAccess, [<$name Meta>]} + $crate::array_access!{variable, self::FieldAccess, [<$name Meta>]} ); }; - (__meta__, fixed_offset($fixed_expr:expr) $field:ident $crate::protocol::meta::Length) => { - impl $crate::protocol::StructLength for Meta { + (__meta__, fixed_offset($fixed_expr:expr) $field:ident $crate::meta::Length) => { + impl $crate::StructLength for Meta { fn length_field_of(of: &Self::Struct<'_>) -> usize { of.$field() } @@ -529,7 +566,7 @@ macro_rules! protocol_builder { } } }; - (__meta__, $fixed:ident($fixed_expr:expr) $field:ident $crate::protocol::meta::Rest) => { + (__meta__, $fixed:ident($fixed_expr:expr) $field:ident $crate::meta::Rest) => { }; (__meta__, $fixed:ident($fixed_expr:expr) $field:ident $any:ty) => { @@ -547,8 +584,8 @@ macro_rules! protocol_builder { $($rest:tt)* },)*), }) => { - paste::paste!( - r#if!(__is_empty__ [$($($variable_marker)?)*] { + $crate::paste!( + $crate::r#if!(__is_empty__ [$($($variable_marker)?)*] { $( #[$sdoc] )? // No variable-sized fields #[derive(Default, Eq, PartialEq)] @@ -564,7 +601,7 @@ macro_rules! protocol_builder { // pattern. $($( #[doc = $fdoc] - pub $field: r#if!(__has__ [$variable_marker] {<$type as $crate::protocol::Enliven>::ForMeasure<'a>}), + pub $field: $crate::r#if!(__has__ [$variable_marker] {<$type as $crate::Enliven>::ForMeasure<'a>}), )?)* } }); @@ -573,8 +610,8 @@ macro_rules! 
protocol_builder { pub const fn measure(&self) -> usize { let mut size = 0; $( - r#if!(__has__ [$($variable_marker)?] { size += $crate::protocol::FieldAccess::<$type>::measure(&self.$field); }); - r#if!(__has__ [$($fixed_marker)?] { size += std::mem::size_of::<$type>(); }); + $crate::r#if!(__has__ [$($variable_marker)?] { size += super::access::FieldAccess::<$type>::measure(&self.$field); }); + $crate::r#if!(__has__ [$($fixed_marker)?] { size += std::mem::size_of::<<$type as $crate::Enliven>::ForBuilder<'static>>(); }); )* size } @@ -594,8 +631,8 @@ macro_rules! protocol_builder { $($rest:tt)* },)*), }) => { - paste::paste!( - r#if!(__is_empty__ [$($($no_value)?)*] { + $crate::paste!( + $crate::r#if!(__is_empty__ [$($($no_value)?)*] { $( #[$sdoc] )? // No unfixed-value fields #[derive(::derive_more::Debug, Default, Eq, PartialEq)] @@ -612,30 +649,30 @@ macro_rules! protocol_builder { // somehow use $no_value in the remainder of the pattern. $($( #[doc = $fdoc] - pub $field: r#if!(__has__ [$no_value] {<$type as $crate::protocol::Enliven>::ForBuilder<'a>}), + pub $field: $crate::r#if!(__has__ [$no_value] {<$type as $crate::Enliven>::ForBuilder<'a>}), )?)* } }); impl B<'_> { #[allow(unused)] - pub fn copy_to_buf(&self, buf: &mut $crate::protocol::writer::BufWriter) { + pub fn copy_to_buf(&self, buf: &mut $crate::BufWriter) { $( - r#if!(__is_empty__ [$($value)?] { - r#if!(__is_empty__ [$($auto)?] { - $crate::protocol::FieldAccess::<$type>::copy_to_buf(buf, self.$field); + $crate::r#if!(__is_empty__ [$($value)?] { + $crate::r#if!(__is_empty__ [$($auto)?] { + <$type as $crate::FieldAccessArray>::copy_to_buf(buf, &self.$field); } else { let auto_offset = buf.size(); - $crate::protocol::FieldAccess::<$type>::copy_to_buf(buf, 0); + <$type as $crate::FieldAccessArray>::copy_to_buf(buf, &0); }); } else { - $crate::protocol::FieldAccess::<$type>::copy_to_buf(buf, $($value)? as usize as _); + <$type as $crate::FieldAccessArray>::copy_to_buf(buf, &($($value)? 
as usize as _)); }); )* $( - r#if!(__has__ [$($auto)?] { - $crate::protocol::FieldAccess::::copy_to_buf_rewind(buf, auto_offset, buf.size() - auto_offset); + $crate::r#if!(__has__ [$($auto)?] { + $crate::FieldAccess::<$crate::meta::Length>::copy_to_buf_rewind(buf, auto_offset, buf.size() - auto_offset); }); )* @@ -646,7 +683,7 @@ macro_rules! protocol_builder { #[allow(unused)] pub fn to_vec(&self) -> Vec { let mut vec = Vec::with_capacity(256); - let mut buf = $crate::protocol::writer::BufWriter::new(&mut vec); + let mut buf = $crate::BufWriter::new(&mut vec); self.copy_to_buf(&mut buf); match buf.finish() { Ok(size) => { @@ -655,7 +692,7 @@ macro_rules! protocol_builder { }, Err(size) => { vec.resize(size, 0); - let mut buf = $crate::protocol::writer::BufWriter::new(&mut vec); + let mut buf = $crate::BufWriter::new(&mut vec); self.copy_to_buf(&mut buf); // Will not fail this second time let size = buf.finish().unwrap(); @@ -669,14 +706,12 @@ macro_rules! protocol_builder { }; } -pub(crate) use {protocol, protocol_builder, r#if, struct_elaborate}; - #[cfg(test)] mod tests { use pretty_assertions::assert_eq; mod fixed_only { - protocol!( + crate::protocol!( struct FixedOnly { a: u8, } @@ -684,20 +719,20 @@ mod tests { } mod fixed_only_value { - protocol!(struct FixedOnlyValue { + crate::protocol!(struct FixedOnlyValue { a: u8 = 1, }); } mod mixed { - protocol!(struct Mixed { + crate::protocol!(struct Mixed { a: u8 = 1, s: ZTString, }); } mod docs { - protocol!( + crate::protocol!( /// Docs struct Docs { /// Docs @@ -709,7 +744,7 @@ mod tests { } mod length { - protocol!( + crate::protocol!( struct WithLength { a: u8, l: len, @@ -718,7 +753,7 @@ mod tests { } mod array { - protocol!( + crate::protocol!( struct StaticArray { a: u8, l: [u8; 4], @@ -727,7 +762,7 @@ mod tests { } mod string { - protocol!( + crate::protocol!( struct HasLString { s: LString, } @@ -736,7 +771,7 @@ mod tests { macro_rules! 
assert_stringify { (($($struct:tt)*), ($($expected:tt)*)) => { - struct_elaborate!(assert_stringify(__internal__ ($($expected)*)) => $($struct)*); + $crate::struct_elaborate!(assert_stringify(__internal__ ($($expected)*)) => $($struct)*); }; (__internal__ ($($expected:tt)*), $($struct:tt)*) => { // We don't want whitespace to impact this comparison @@ -792,7 +827,7 @@ mod tests { fixed(fixed_offset = fixed_offset, (0)), }, { - name(l), type (crate::protocol::meta::Length), size(fixed = fixed), + name(l), type (crate::meta::Length), size(fixed = fixed), value(auto = auto), docs(concat!("`", stringify! (l), "` field.")), fixed(fixed_offset = fixed_offset, ((0) + std::mem::size_of::())), }, @@ -808,7 +843,7 @@ mod tests { fixed(no_fixed_offset = no_fixed_offset, (0)), }, { - name(d), type ([u8; 4]), size(fixed = fixed), + name(d), type (crate::meta::FixedArray<4, u8>), size(fixed = fixed), value(no_value = no_value), docs(concat!("`", stringify! (d), "` field.")), fixed(no_fixed_offset = no_fixed_offset, ((0) + std::mem::size_of::())), @@ -818,7 +853,8 @@ mod tests { value(no_value = no_value), docs(concat!("`", stringify! (e), "` field.")), fixed(no_fixed_offset = no_fixed_offset, - (((0) + std::mem::size_of::()) + std::mem::size_of::<[u8; 4]>())), + (((0) + std::mem::size_of::()) + + std::mem::size_of::<[u8; 4]>())), }, ), })); diff --git a/rust/db_proto/src/lib.rs b/rust/db_proto/src/lib.rs new file mode 100644 index 00000000000..537ed319568 --- /dev/null +++ b/rust/db_proto/src/lib.rs @@ -0,0 +1,218 @@ +mod arrays; +mod buffer; +mod datatypes; +mod field_access; +mod gen; +mod message_group; +mod writer; + +#[doc(hidden)] +pub mod test_protocol; + +/// Metatypes for the protocol and related arrays/strings. 
+pub mod meta { + pub use super::arrays::meta::*; + pub use super::datatypes::meta::*; +} + +#[allow(unused)] +pub use arrays::{Array, ArrayIter, ZTArray, ZTArrayIter}; +pub use buffer::StructBuffer; +#[allow(unused)] +pub use datatypes::{Encoded, LString, Length, Rest, Uuid, ZTString}; +pub use field_access::{FieldAccess, FieldAccessArray, FixedSize}; +pub use writer::BufWriter; + +#[doc(inline)] +pub use gen::protocol; +#[doc(inline)] +pub use message_group::{match_message, message_group}; + +/// Re-export for the `protocol!` macro. +#[doc(hidden)] +pub use paste::paste; + +#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq, Eq)] +pub enum ParseError { + #[error("Buffer is too short")] + TooShort, + #[error("Invalid data")] + InvalidData, +} + +/// Implemented for all structs. +pub trait StructMeta { + type Struct<'a>: std::fmt::Debug; + fn new(buf: &[u8]) -> Result, ParseError>; + fn to_vec(s: &Self::Struct<'_>) -> Vec; +} + +/// Implemented for all generated structs that have a [`meta::Length`] field at a fixed offset. +pub trait StructLength: StructMeta { + fn length_field_of(of: &Self::Struct<'_>) -> usize; + fn length_field_offset() -> usize; + fn length_of_buf(buf: &[u8]) -> Option { + if buf.len() < Self::length_field_offset() + std::mem::size_of::() { + None + } else { + let len = FieldAccess::::extract( + &buf[Self::length_field_offset() + ..Self::length_field_offset() + std::mem::size_of::()], + ) + .ok()?; + Some(Self::length_field_offset() + len) + } + } +} + +/// For a given metaclass, returns the inflated type, a measurement type and a +/// builder type. +/// +/// Types that don't include a lifetime can use the same type for the meta type +/// and the `WithLifetime` type. 
+pub trait Enliven { + type WithLifetime<'a>; + type ForMeasure<'a>: 'a; + type ForBuilder<'a>: 'a; +} + +#[derive(Debug, Eq, PartialEq)] +pub enum MetaRelation { + Parent, + Length, + Item, + Field(&'static str), +} + +pub trait Meta { + fn name(&self) -> &'static str { + std::any::type_name::() + } + fn relations(&self) -> &'static [(MetaRelation, &'static dyn Meta)] { + &[] + } + fn fixed_length(&self) -> Option { + None + } + fn field(&self, name: &'static str) -> Option<&'static dyn Meta> { + for (relation, meta) in self.relations() { + if relation == &MetaRelation::Field(name) { + return Some(*meta); + } + } + None + } + fn parent(&self) -> Option<&'static dyn Meta> { + for (relation, meta) in self.relations() { + if relation == &MetaRelation::Parent { + return Some(*meta); + } + } + None + } +} + +impl PartialEq for dyn Meta { + fn eq(&self, other: &T) -> bool { + other.name() == self.name() + } +} + +impl std::fmt::Debug for dyn Meta { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut s = f.debug_struct(self.name()); + if let Some(length) = self.fixed_length() { + s.field("Length", &length); + } + for (relation, meta) in self.relations() { + if relation == &MetaRelation::Parent { + s.field(&format!("{relation:?}"), &meta.name()); + } else { + s.field(&format!("{relation:?}"), meta); + } + } + s.finish() + } +} + +/// Used internally by the `protocol!` macro to copy from `FieldAccess` in this crate to +/// `FieldAccess` in the generated code. +#[macro_export] +#[doc(hidden)] +macro_rules! 
field_access_copy { + ($acc1:ident :: FieldAccess, $acc2:ident :: FieldAccess, $($ty:ty),*) => { + $( + $crate::field_access_copy!(: $acc1 :: FieldAccess, $acc2 :: FieldAccess, + $ty, + $crate::meta::ZTArray<$ty>, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array + ); + )* + }; + + (basic $acc1:ident :: FieldAccess, $acc2:ident :: FieldAccess, $($ty:ty),*) => { + $( + + $crate::field_access_copy!(: $acc1 :: FieldAccess, $acc2 :: FieldAccess, + $ty, + $crate::meta::ZTArray<$ty>, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array, + $crate::meta::Array + ); + + impl $acc2 :: FieldAccess<$crate::meta::FixedArray> { + #[inline(always)] + pub const fn meta() -> &'static dyn $crate::Meta { + $acc1::FieldAccess::<$crate::meta::FixedArray>::meta() + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + $acc1::FieldAccess::<$crate::meta::FixedArray>::size_of_field_at(buf) + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> Result<[<$ty as $crate::Enliven>::WithLifetime<'_>; S], $crate::ParseError> { + $acc1::FieldAccess::<$crate::meta::FixedArray>::extract(buf) + } + pub const fn constant(_: usize) -> $ty { + panic!("Constants unsupported for this data type") + } + #[inline(always)] + pub const fn measure(value: &[<$ty as $crate::Enliven>::ForMeasure<'_>; S]) -> usize { + $acc1::FieldAccess::<$crate::meta::FixedArray>::measure(value) + } + } + )* + }; + (: $acc1:ident :: FieldAccess, $acc2:ident :: FieldAccess, $($ty:ty),*) => { + $( + impl $acc2 :: FieldAccess<$ty> { + #[inline(always)] + pub const fn meta() -> &'static dyn $crate::Meta { + $acc1::FieldAccess::<$ty>::meta() + } + #[inline(always)] + pub const fn size_of_field_at(buf: &[u8]) -> Result { + $acc1::FieldAccess::<$ty>::size_of_field_at(buf) + } + #[inline(always)] + pub const fn extract(buf: &[u8]) -> Result<<$ty as $crate::Enliven>::WithLifetime<'_>, $crate::ParseError> { 
+ $acc1::FieldAccess::<$ty>::extract(buf) + } + pub const fn constant(value: usize) -> <$ty as $crate::Enliven>::WithLifetime<'static> { + $acc1::FieldAccess::<$ty>::constant(value) + } + #[inline(always)] + pub const fn measure(value: &<$ty as $crate::Enliven>::ForMeasure<'_>) -> usize { + $acc1::FieldAccess::<$ty>::measure(value) + } + } + )* + }; +} diff --git a/rust/pgrust/src/protocol/message_group.rs b/rust/db_proto/src/message_group.rs similarity index 85% rename from rust/pgrust/src/protocol/message_group.rs rename to rust/db_proto/src/message_group.rs index 7201864dc87..a88fade9d59 100644 --- a/rust/pgrust/src/protocol/message_group.rs +++ b/rust/db_proto/src/message_group.rs @@ -1,6 +1,8 @@ -macro_rules! message_group { +#[doc(hidden)] +#[macro_export] +macro_rules! __message_group { ($(#[$doc:meta])* $group:ident : $super:ident = [$($message:ident),*]) => { - paste::paste!( + $crate::paste!( $(#[$doc])* #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[allow(unused)] @@ -29,7 +31,7 @@ macro_rules! message_group { } } - pub fn copy_to_buf(&self, writer: &mut $crate::protocol::writer::BufWriter) { + pub fn copy_to_buf(&self, writer: &mut $crate::BufWriter) { match self { $( Self::$message(message) => message.copy_to_buf(writer), @@ -65,7 +67,7 @@ macro_rules! message_group { impl $group { pub fn identify(buf: &[u8]) -> Option { $( - if ::WithLifetime::is_buffer(buf) { + if ::WithLifetime::is_buffer(buf) { return Some(Self::$message); } )* @@ -76,17 +78,19 @@ macro_rules! message_group { ); }; } -pub(crate) use message_group; + +#[doc(inline)] +pub use __message_group as message_group; /// Perform a match on a message. 
/// /// ```rust -/// use pgrust::protocol::*; -/// use pgrust::protocol::postgres::data::*; +/// use db_proto::*; +/// use db_proto::test_protocol::data::*; /// /// let buf = [b'?', 0, 0, 0, 4]; /// match_message!(Message::new(&buf), Backend { -/// (BackendKeyData as data) => { +/// (DataRow as data) => { /// todo!(); /// }, /// unknown => { @@ -102,7 +106,7 @@ macro_rules! __match_message { $unknown:ident => $unknown_impl:block $(,)? }) => { 'block: { - let __message: Result<_, $crate::protocol::ParseError> = $buf; + let __message: Result<_, $crate::ParseError> = $buf; let res = match __message { Ok(__message) => { $( @@ -138,18 +142,15 @@ pub use __match_message as match_message; #[cfg(test)] mod tests { use super::*; - use crate::protocol::postgres::{ - builder, - data::{Message, PasswordMessage}, - }; + use crate::test_protocol::{builder, data::*}; #[test] fn test_match() { let message = builder::Sync::default().to_vec(); let message = Message::new(&message); match_message!(message, Message { - (PasswordMessage as password) => { - eprintln!("{password:?}"); + (DataRow as data_row) => { + eprintln!("{data_row:?}"); return; }, unknown => { diff --git a/rust/db_proto/src/test_protocol.rs b/rust/db_proto/src/test_protocol.rs new file mode 100644 index 00000000000..344eaeed63a --- /dev/null +++ b/rust/db_proto/src/test_protocol.rs @@ -0,0 +1,140 @@ +//! A pseudo-Postgres protocol for testing. +use crate::gen::protocol; + +protocol!( + struct Message { + /// The message type. + mtype: u8, + /// The length of the message contents in bytes, including self. + mlen: len, + /// The message contents. + data: Rest, + } + + /// The `CommandComplete` struct represents a message indicating the successful completion of a command. + struct CommandComplete: Message { + /// Identifies the message as a command-completed response. + mtype: u8 = 'C', + /// Length of message contents in bytes, including self. + mlen: len, + /// The command tag. 
+ tag: ZTString, + } + + /// The `Sync` message is used to synchronize the client and server. + struct Sync: Message { + /// Identifies the message as a synchronization request. + mtype: u8 = 'S', + /// Length of message contents in bytes, including self. + mlen: len, + } + + /// The `DataRow` message represents a row of data returned from a query. + struct DataRow: Message { + /// Identifies the message as a data row. + mtype: u8 = 'D', + /// Length of message contents in bytes, including self. + mlen: len, + /// The values in the row. + values: Array, + } + + struct QueryType { + /// The type of the query parameter. + typ: u8, + /// The length of the query parameter. + len: u32, + /// The metadata of the query parameter. + meta: Array, + } + + struct Query: Message { + /// Identifies the message as a query. + mtype: u8 = 'Q', + /// Length of message contents in bytes, including self. + mlen: len, + /// The query string. + query: ZTString, + /// The types of the query parameters. + types: Array, + } + + struct Key { + /// The key. + key: [u8; 16], + } + + struct Uuids { + /// The UUIDs. 
+ uuids: Array, + } +); + +#[cfg(test)] +mod tests { + use uuid::Uuid; + + use super::*; + + #[test] + fn test_meta() { + let expected = [ + r#"Message { Field("mtype"): u8, Field("mlen"): len, Field("data"): Rest }"#, + r#"CommandComplete { Parent: "Message", Field("mtype"): u8, Field("mlen"): len, Field("tag"): ZTString }"#, + r#"Sync { Parent: "Message", Field("mtype"): u8, Field("mlen"): len }"#, + r#"DataRow { Parent: "Message", Field("mtype"): u8, Field("mlen"): len, Field("values"): Array { Length: i16, Item: Encoded } }"#, + r#"QueryType { Field("typ"): u8, Field("len"): u32, Field("meta"): Array { Length: u32, Item: u8 } }"#, + r#"Query { Parent: "Message", Field("mtype"): u8, Field("mlen"): len, Field("query"): ZTString, Field("types"): Array { Length: i16, Item: QueryType { Field("typ"): u8, Field("len"): u32, Field("meta"): Array { Length: u32, Item: u8 } } } }"#, + r#"Key { Field("key"): FixedArray { Length: 16, Item: u8 } }"#, + r#"Uuids { Field("uuids"): Array { Length: u32, Item: Uuid } }"#, + ]; + + for (i, meta) in meta::ALL.iter().enumerate() { + assert_eq!(expected[i], format!("{meta:?}")); + } + } + + #[test] + fn test_query() { + let buf = builder::Query { + query: "SELECT * from foo", + types: &[builder::QueryType { + typ: 1, + len: 4, + meta: &[1, 2, 3, 4], + }], + } + .to_vec(); + + let query = data::Query::new(&buf).expect("Failed to parse query"); + assert_eq!( + r#"Query { mtype: 81, mlen: 37, query: "SELECT * from foo", types: [QueryType { typ: 1, len: 4, meta: [1, 2, 3, 4] }] }"#, + format!("{query:?}") + ); + } + + #[test] + fn test_fixed_array() { + let buf = builder::Key { + key: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + } + .to_vec(); + + let key = data::Key::new(&buf).expect("Failed to parse key"); + assert_eq!( + key.key(), + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + ); + } + + #[test] + fn test_uuid() { + let buf = builder::Uuids { + uuids: &[Uuid::NAMESPACE_DNS], + } + .to_vec(); + + let uuids = 
data::Uuids::new(&buf).expect("Failed to parse uuids"); + assert_eq!(uuids.uuids().get(0), Some(Uuid::NAMESPACE_DNS)); + } +} diff --git a/rust/pgrust/src/protocol/writer.rs b/rust/db_proto/src/writer.rs similarity index 100% rename from rust/pgrust/src/protocol/writer.rs rename to rust/db_proto/src/writer.rs diff --git a/rust/pgrust/Cargo.toml b/rust/pgrust/Cargo.toml index 05c6961e53e..e67df1a7da9 100644 --- a/rust/pgrust/Cargo.toml +++ b/rust/pgrust/Cargo.toml @@ -17,6 +17,7 @@ gel_auth.workspace = true pyo3.workspace = true tokio.workspace = true tracing.workspace = true +db_proto.workspace = true futures = "0" thiserror = "1" @@ -30,7 +31,6 @@ url = "2" serde = "1" serde_derive = "1" percent-encoding = "2" -uuid = "1" bytemuck = { version = "1", features = ["derive"] } [dependencies.derive_more] diff --git a/rust/pgrust/src/connection/conn.rs b/rust/pgrust/src/connection/conn.rs index 29a209f366d..05be12d3c91 100644 --- a/rust/pgrust/src/connection/conn.rs +++ b/rust/pgrust/src/connection/conn.rs @@ -11,15 +11,13 @@ use crate::{ ConnectionError, }, handshake::ConnectionSslRequirement, - protocol::{ - postgres::{ - builder, - data::{Message, NotificationResponse, ParameterStatus}, - meta, - }, - StructBuffer, + protocol::postgres::{ + builder, + data::{Message, NotificationResponse, ParameterStatus}, + meta, }, }; +use db_proto::StructBuffer; use futures::{future::Either, FutureExt}; use std::{ cell::RefCell, diff --git a/rust/pgrust/src/connection/flow.rs b/rust/pgrust/src/connection/flow.rs index 15b3628b175..eed27da32eb 100644 --- a/rust/pgrust/src/connection/flow.rs +++ b/rust/pgrust/src/connection/flow.rs @@ -1,8 +1,8 @@ //! Postgres flow notes: //! -//! https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-PIPELINING +//! //! -//! https://segmentfault.com/a/1190000017136059 +//! //! //! Extended query messages Parse, Bind, Describe, Execute, Close put the server //! into a "skip-til-sync" mode when erroring. 
All messages other than Terminate (including @@ -17,18 +17,15 @@ use std::{cell::RefCell, num::NonZeroU32, rc::Rc}; -use crate::protocol::{ - match_message, - postgres::{ - builder, - data::{ - BindComplete, CloseComplete, CommandComplete, CopyData, CopyDone, CopyOutResponse, - DataRow, EmptyQueryResponse, ErrorResponse, Message, NoData, NoticeResponse, - ParameterDescription, ParseComplete, PortalSuspended, ReadyForQuery, RowDescription, - }, +use crate::protocol::postgres::{ + builder, + data::{ + BindComplete, CloseComplete, CommandComplete, CopyData, CopyDone, CopyOutResponse, DataRow, + EmptyQueryResponse, ErrorResponse, Message, NoData, NoticeResponse, ParameterDescription, + ParseComplete, PortalSuspended, ReadyForQuery, RowDescription, }, - Encoded, }; +use db_proto::{match_message, Encoded}; #[derive(Debug, Clone, Copy)] pub enum Param<'a> { diff --git a/rust/pgrust/src/connection/mod.rs b/rust/pgrust/src/connection/mod.rs index c0a7e929759..696367a9a0f 100644 --- a/rust/pgrust/src/connection/mod.rs +++ b/rust/pgrust/src/connection/mod.rs @@ -1,10 +1,7 @@ use std::collections::HashMap; -use crate::{ - errors::{edgedb::EdbError, PgServerError}, - protocol::ParseError, -}; - +use crate::errors::{edgedb::EdbError, PgServerError}; +use db_proto::ParseError; mod conn; pub mod dsn; mod flow; diff --git a/rust/pgrust/src/connection/raw_conn.rs b/rust/pgrust/src/connection/raw_conn.rs index dbd0930976e..07eaeec5e7c 100644 --- a/rust/pgrust/src/connection/raw_conn.rs +++ b/rust/pgrust/src/connection/raw_conn.rs @@ -10,7 +10,8 @@ use crate::handshake::{ ConnectionSslRequirement, }; use crate::protocol::postgres::{FrontendBuilder, InitialBuilder}; -use crate::protocol::{postgres::data::SSLResponse, postgres::meta, StructBuffer}; +use crate::protocol::{postgres::data::SSLResponse, postgres::meta}; +use db_proto::StructBuffer; use gel_auth::AuthType; use std::collections::HashMap; use std::pin::Pin; diff --git a/rust/pgrust/src/errors/mod.rs 
b/rust/pgrust/src/errors/mod.rs index 39e7e538e28..a414b2ee112 100644 --- a/rust/pgrust/src/errors/mod.rs +++ b/rust/pgrust/src/errors/mod.rs @@ -135,7 +135,7 @@ macro_rules! pg_error { )* paste!( - /// Postgres error codes. See https://www.postgresql.org/docs/current/errcodes-appendix.html. + /// Postgres error codes. See . #[derive(Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] pub enum PgError { $( diff --git a/rust/pgrust/src/handshake/client_state_machine.rs b/rust/pgrust/src/handshake/client_state_machine.rs index 7723f22b3d9..e0fa6583ea0 100644 --- a/rust/pgrust/src/handshake/client_state_machine.rs +++ b/rust/pgrust/src/handshake/client_state_machine.rs @@ -3,7 +3,6 @@ use crate::{ connection::{invalid_state, ConnectionError, Credentials, SslError}, errors::PgServerError, protocol::{ - match_message, postgres::data::{ AuthenticationCleartextPassword, AuthenticationMD5Password, AuthenticationMessage, AuthenticationOk, AuthenticationSASL, AuthenticationSASLContinue, @@ -11,10 +10,10 @@ use crate::{ ReadyForQuery, SSLResponse, }, postgres::{builder, FrontendBuilder, InitialBuilder}, - ParseError, }, }; use base64::Engine; +use db_proto::{match_message, ParseError}; use gel_auth::{ scram::{generate_salted_password, ClientEnvironment, ClientTransaction, Sha256Out}, AuthType, diff --git a/rust/pgrust/src/handshake/edgedb_server.rs b/rust/pgrust/src/handshake/edgedb_server.rs index faf3669a2e1..c5cf6915ab7 100644 --- a/rust/pgrust/src/handshake/edgedb_server.rs +++ b/rust/pgrust/src/handshake/edgedb_server.rs @@ -1,11 +1,9 @@ use crate::{ connection::ConnectionError, errors::edgedb::EdbError, - protocol::{ - edgedb::{data::*, *}, - match_message, ParseError, StructBuffer, - }, + protocol::edgedb::{data::*, *}, }; +use db_proto::{match_message, ParseError, StructBuffer}; use gel_auth::{ handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse}, AuthType, CredentialData, diff --git a/rust/pgrust/src/handshake/server_state_machine.rs 
b/rust/pgrust/src/handshake/server_state_machine.rs index 8c62c00b610..f7aa30f316f 100644 --- a/rust/pgrust/src/handshake/server_state_machine.rs +++ b/rust/pgrust/src/handshake/server_state_machine.rs @@ -5,12 +5,9 @@ use crate::{ PgError, PgErrorConnectionException, PgErrorFeatureNotSupported, PgErrorInvalidAuthorizationSpecification, PgServerError, PgServerErrorField, }, - protocol::{ - match_message, - postgres::{data::*, *}, - ParseError, StructBuffer, - }, + protocol::postgres::{data::*, *}, }; +use db_proto::{match_message, ParseError, StructBuffer}; use gel_auth::{ handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse}, AuthType, CredentialData, diff --git a/rust/pgrust/src/protocol/datatypes.rs b/rust/pgrust/src/protocol/datatypes.rs deleted file mode 100644 index b06b98575ea..00000000000 --- a/rust/pgrust/src/protocol/datatypes.rs +++ /dev/null @@ -1,788 +0,0 @@ -use std::{marker::PhantomData, str::Utf8Error}; - -use uuid::Uuid; - -use super::{ - arrays::{array_access, Array, ArrayMeta}, - field_access, - writer::BufWriter, - Enliven, FieldAccess, Meta, ParseError, -}; - -pub mod meta { - pub use super::EncodedMeta as Encoded; - pub use super::LStringMeta as LString; - pub use super::LengthMeta as Length; - pub use super::RestMeta as Rest; - pub use super::UuidMeta as Uuid; - pub use super::ZTStringMeta as ZTString; -} - -/// Represents the remainder of data in a message. 
-#[derive(Debug, PartialEq, Eq)] -pub struct Rest<'a> { - buf: &'a [u8], -} - -field_access!(RestMeta); - -pub struct RestMeta {} -impl Meta for RestMeta { - fn name(&self) -> &'static str { - "Rest" - } -} -impl Enliven for RestMeta { - type WithLifetime<'a> = Rest<'a>; - type ForMeasure<'a> = &'a [u8]; - type ForBuilder<'a> = &'a [u8]; -} - -impl<'a> Rest<'a> {} - -impl<'a> AsRef<[u8]> for Rest<'a> { - fn as_ref(&self) -> &[u8] { - self.buf - } -} - -impl<'a> std::ops::Deref for Rest<'a> { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - self.buf - } -} - -impl PartialEq<[u8]> for Rest<'_> { - fn eq(&self, other: &[u8]) -> bool { - self.buf == other - } -} - -impl PartialEq<&[u8; N]> for Rest<'_> { - fn eq(&self, other: &&[u8; N]) -> bool { - self.buf == *other - } -} - -impl PartialEq<&[u8]> for Rest<'_> { - fn eq(&self, other: &&[u8]) -> bool { - self.buf == *other - } -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &RestMeta {} - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - Ok(buf.len()) - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result, ParseError> { - Ok(Rest { buf }) - } - #[inline(always)] - pub const fn measure(buf: &[u8]) -> usize { - buf.len() - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: &[u8]) { - buf.write(value) - } -} - -/// A zero-terminated string. 
-#[allow(unused)] -pub struct ZTString<'a> { - buf: &'a [u8], -} - -field_access!(ZTStringMeta); -array_access!(ZTStringMeta); - -pub struct ZTStringMeta {} -impl Meta for ZTStringMeta { - fn name(&self) -> &'static str { - "ZTString" - } -} - -impl Enliven for ZTStringMeta { - type WithLifetime<'a> = ZTString<'a>; - type ForMeasure<'a> = &'a str; - type ForBuilder<'a> = &'a str; -} - -impl std::fmt::Debug for ZTString<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - String::from_utf8_lossy(self.buf).fmt(f) - } -} - -impl<'a> ZTString<'a> { - pub fn to_owned(&self) -> Result { - std::str::from_utf8(self.buf).map(|s| s.to_owned()) - } - - pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> { - std::str::from_utf8(self.buf) - } - - pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { - String::from_utf8_lossy(self.buf) - } - - pub fn to_bytes(&self) -> &[u8] { - self.buf - } -} - -impl PartialEq for ZTString<'_> { - fn eq(&self, other: &Self) -> bool { - self.buf == other.buf - } -} -impl Eq for ZTString<'_> {} - -impl PartialEq for ZTString<'_> { - fn eq(&self, other: &str) -> bool { - self.buf == other.as_bytes() - } -} - -impl PartialEq<&str> for ZTString<'_> { - fn eq(&self, other: &&str) -> bool { - self.buf == other.as_bytes() - } -} - -impl<'a> TryInto<&'a str> for ZTString<'a> { - type Error = Utf8Error; - fn try_into(self) -> Result<&'a str, Self::Error> { - std::str::from_utf8(self.buf) - } -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &ZTStringMeta {} - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - let mut i = 0; - loop { - if i >= buf.len() { - return Err(ParseError::TooShort); - } - if buf[i] == 0 { - return Ok(i + 1); - } - i += 1; - } - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result, ParseError> { - let buf = buf.split_at(buf.len() - 1).0; - Ok(ZTString { buf }) - } - #[inline(always)] - pub const fn 
measure(buf: &str) -> usize { - buf.len() + 1 - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: &str) { - buf.write(value.as_bytes()); - buf.write_u8(0); - } - #[inline(always)] - pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &str) { - buf.write(value.as_bytes()); - buf.write_u8(0); - } -} - -/// A length-prefixed string. -#[allow(unused)] -pub struct LString<'a> { - buf: &'a [u8], -} - -field_access!(LStringMeta); -array_access!(LStringMeta); - -pub struct LStringMeta {} -impl Meta for LStringMeta { - fn name(&self) -> &'static str { - "LString" - } -} - -impl Enliven for LStringMeta { - type WithLifetime<'a> = LString<'a>; - type ForMeasure<'a> = &'a str; - type ForBuilder<'a> = &'a str; -} - -impl std::fmt::Debug for LString<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - String::from_utf8_lossy(self.buf).fmt(f) - } -} - -impl<'a> LString<'a> { - pub fn to_owned(&self) -> Result { - std::str::from_utf8(self.buf).map(|s| s.to_owned()) - } - - pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> { - std::str::from_utf8(self.buf) - } - - pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { - String::from_utf8_lossy(self.buf) - } - - pub fn to_bytes(&self) -> &[u8] { - self.buf - } -} - -impl PartialEq for LString<'_> { - fn eq(&self, other: &Self) -> bool { - self.buf == other.buf - } -} -impl Eq for LString<'_> {} - -impl PartialEq for LString<'_> { - fn eq(&self, other: &str) -> bool { - self.buf == other.as_bytes() - } -} - -impl PartialEq<&str> for LString<'_> { - fn eq(&self, other: &&str) -> bool { - self.buf == other.as_bytes() - } -} - -impl<'a> TryInto<&'a str> for LString<'a> { - type Error = Utf8Error; - fn try_into(self) -> Result<&'a str, Self::Error> { - std::str::from_utf8(self.buf) - } -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &LStringMeta {} - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - if 
buf.len() < 4 { - return Err(ParseError::TooShort); - } - let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; - Ok(4 + len) - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result, ParseError> { - if buf.len() < 4 { - return Err(ParseError::TooShort); - } - let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; - if buf.len() < 4 + len { - return Err(ParseError::TooShort); - } - Ok(LString { - buf: buf.split_at(4).1, - }) - } - #[inline(always)] - pub const fn measure(buf: &str) -> usize { - 4 + buf.len() - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: &str) { - let len = value.len() as u32; - buf.write(&len.to_be_bytes()); - buf.write(value.as_bytes()); - } - #[inline(always)] - pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &str) { - let len = value.len() as u32; - buf.write(&len.to_be_bytes()); - buf.write(value.as_bytes()); - } -} - -field_access!(UuidMeta); -array_access!(UuidMeta); - -pub struct UuidMeta {} -impl Meta for UuidMeta { - fn name(&self) -> &'static str { - "Uuid" - } -} - -impl Enliven for UuidMeta { - type WithLifetime<'a> = Uuid; - type ForMeasure<'a> = Uuid; - type ForBuilder<'a> = Uuid; -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &UuidMeta {} - } - - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - if buf.len() < 16 { - Err(ParseError::TooShort) - } else { - Ok(16) - } - } - - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result { - if let Some(bytes) = buf.first_chunk() { - Ok(Uuid::from_u128(::from_be_bytes(*bytes))) - } else { - Err(ParseError::TooShort) - } - } - - #[inline(always)] - pub const fn measure(_value: &Uuid) -> usize { - 16 - } - - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: Uuid) { - buf.write(value.as_bytes().as_slice()); - } - - #[inline(always)] - pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &Uuid) { - 
buf.write(value.as_bytes().as_slice()); - } -} - -#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] -/// An encoded row value. -pub enum Encoded<'a> { - #[default] - Null, - Value(&'a [u8]), -} - -impl<'a> Encoded<'a> { - pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { - match self { - Encoded::Null => "".into(), - Encoded::Value(value) => String::from_utf8_lossy(value), - } - } -} - -impl<'a> AsRef> for Encoded<'a> { - fn as_ref(&self) -> &Encoded<'a> { - self - } -} - -field_access!(EncodedMeta); -array_access!(EncodedMeta); - -pub struct EncodedMeta {} -impl Meta for EncodedMeta { - fn name(&self) -> &'static str { - "Encoded" - } -} - -impl Enliven for EncodedMeta { - type WithLifetime<'a> = Encoded<'a>; - type ForMeasure<'a> = Encoded<'a>; - type ForBuilder<'a> = Encoded<'a>; -} - -impl<'a> Encoded<'a> {} - -impl PartialEq for Encoded<'_> { - fn eq(&self, other: &str) -> bool { - self == &Encoded::Value(other.as_bytes()) - } -} - -impl PartialEq<&str> for Encoded<'_> { - fn eq(&self, other: &&str) -> bool { - self == &Encoded::Value(other.as_bytes()) - } -} - -impl PartialEq<[u8]> for Encoded<'_> { - fn eq(&self, other: &[u8]) -> bool { - self == &Encoded::Value(other) - } -} - -impl PartialEq<&[u8]> for Encoded<'_> { - fn eq(&self, other: &&[u8]) -> bool { - self == &Encoded::Value(other) - } -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &EncodedMeta {} - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - const N: usize = std::mem::size_of::(); - if let Some(len) = buf.first_chunk::() { - let len = i32::from_be_bytes(*len); - if len == -1 { - Ok(N) - } else if len < 0 { - Err(ParseError::InvalidData) - } else if buf.len() < len as usize + N { - Err(ParseError::TooShort) - } else { - Ok(len as usize + N) - } - } else { - Err(ParseError::TooShort) - } - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result, ParseError> { - const N: usize = 
std::mem::size_of::(); - if let Some((len, array)) = buf.split_first_chunk::() { - let len = i32::from_be_bytes(*len); - if len == -1 && array.is_empty() { - Ok(Encoded::Null) - } else if len < 0 { - Err(ParseError::InvalidData) - } else if array.len() < len as _ { - Err(ParseError::TooShort) - } else { - Ok(Encoded::Value(array)) - } - } else { - Err(ParseError::TooShort) - } - } - #[inline(always)] - pub const fn measure(value: &Encoded) -> usize { - match value { - Encoded::Null => std::mem::size_of::(), - Encoded::Value(value) => value.len() + std::mem::size_of::(), - } - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: Encoded) { - Self::copy_to_buf_ref(buf, &value) - } - #[inline(always)] - pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &Encoded) { - match value { - Encoded::Null => buf.write(&[0xff, 0xff, 0xff, 0xff]), - Encoded::Value(value) => { - let len: i32 = value.len() as _; - buf.write(&len.to_be_bytes()); - buf.write(value); - } - } - } -} - -// We alias usize here. Note that if this causes trouble in the future we can -// probably work around this by adding a new "const value" function to -// FieldAccess. For now it works! 
-pub struct LengthMeta(#[allow(unused)] i32); -impl Enliven for LengthMeta { - type WithLifetime<'a> = usize; - type ForMeasure<'a> = usize; - type ForBuilder<'a> = usize; -} -impl Meta for LengthMeta { - fn name(&self) -> &'static str { - "len" - } -} - -impl FieldAccess { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - &LengthMeta(0) - } - #[inline(always)] - pub const fn constant(value: usize) -> LengthMeta { - LengthMeta(value as i32) - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - match FieldAccess::::extract(buf) { - Ok(n) if n >= 0 => Ok(std::mem::size_of::()), - Ok(_) => Err(ParseError::InvalidData), - Err(e) => Err(e), - } - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result { - match FieldAccess::::extract(buf) { - Ok(n) if n >= 0 => Ok(n as _), - Ok(_) => Err(ParseError::InvalidData), - Err(e) => Err(e), - } - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: usize) { - FieldAccess::::copy_to_buf(buf, value as i32) - } - #[inline(always)] - pub fn copy_to_buf_rewind(buf: &mut BufWriter, rewind: usize, value: usize) { - FieldAccess::::copy_to_buf_rewind(buf, rewind, value as i32) - } -} - -macro_rules! 
basic_types { - ($($ty:ty)*) => { - $( - field_access!{$ty} - - impl Enliven for $ty { - type WithLifetime<'a> = $ty; - type ForMeasure<'a> = $ty; - type ForBuilder<'a> = $ty; - } - - impl Enliven for [$ty; S] { - type WithLifetime<'a> = [$ty; S]; - type ForMeasure<'a> = [$ty; S]; - type ForBuilder<'a> = [$ty; S]; - } - - #[allow(unused)] - impl FieldAccess<$ty> { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - struct Meta {} - impl $crate::protocol::Meta for Meta { - fn name(&self) -> &'static str { - stringify!($ty) - } - } - &Meta{} - } - #[inline(always)] - pub const fn constant(value: usize) -> $ty { - value as _ - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - let size = std::mem::size_of::<$ty>(); - if size > buf.len() { - Err($crate::protocol::ParseError::TooShort) - } else { - Ok(size) - } - } - #[inline(always)] - pub const fn extract(buf: &[u8]) -> Result<$ty, $crate::protocol::ParseError> { - if let Some(bytes) = buf.first_chunk() { - Ok(<$ty>::from_be_bytes(*bytes)) - } else { - Err($crate::protocol::ParseError::TooShort) - } - } - #[inline(always)] - pub fn copy_to_buf(buf: &mut BufWriter, value: $ty) { - buf.write(&<$ty>::to_be_bytes(value)); - } - #[inline(always)] - pub fn copy_to_buf_rewind(buf: &mut BufWriter, rewind: usize, value: $ty) { - buf.write_rewind(rewind, &<$ty>::to_be_bytes(value)); - } - } - - #[allow(unused)] - impl FieldAccess<[$ty; S]> { - #[inline(always)] - pub const fn meta() -> &'static dyn Meta { - struct Meta {} - impl $crate::protocol::Meta for Meta { - fn name(&self) -> &'static str { - // TODO: can we extract this constant? 
- concat!('[', stringify!($ty), "; ", "S") - } - } - &Meta{} - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - let size = std::mem::size_of::<$ty>() * S; - if size > buf.len() { - Err($crate::protocol::ParseError::TooShort) - } else { - Ok(size) - } - } - #[inline(always)] - pub const fn extract(mut buf: &[u8]) -> Result<[$ty; S], $crate::protocol::ParseError> { - let mut out: [$ty; S] = [0; S]; - let mut i = 0; - loop { - if i == S { - break; - } - (out[i], buf) = if let Some((bytes, rest)) = buf.split_first_chunk() { - (<$ty>::from_be_bytes(*bytes), rest) - } else { - return Err($crate::protocol::ParseError::TooShort) - }; - i += 1; - } - Ok(out) - } - #[inline(always)] - pub fn copy_to_buf(mut buf: &mut BufWriter, value: [$ty; S]) { - if !buf.test(std::mem::size_of::<$ty>() * S) { - return; - } - for n in value { - buf.write(&<$ty>::to_be_bytes(n)); - } - } - } - - impl $crate::protocol::FixedSize for $ty { - const SIZE: usize = std::mem::size_of::<$ty>(); - #[inline(always)] - fn extract_infallible(buf: &[u8]) -> $ty { - if let Some(buf) = buf.first_chunk() { - <$ty>::from_be_bytes(*buf) - } else { - panic!() - } - } - } - impl $crate::protocol::FixedSize for [$ty; S] { - const SIZE: usize = std::mem::size_of::<$ty>() * S; - #[inline(always)] - fn extract_infallible(mut buf: &[u8]) -> [$ty; S] { - let mut out: [$ty; S] = [0; S]; - let mut i = 0; - loop { - if i == S { - break; - } - (out[i], buf) = if let Some((bytes, rest)) = buf.split_first_chunk() { - (<$ty>::from_be_bytes(*bytes), rest) - } else { - panic!() - }; - i += 1; - } - out - } - } - - basic_types!(: array<$ty> u8 i16 i32 u32 u64); - )* - }; - - (: array<$ty:ty> $($len:ty)*) => { - $( - #[allow(unused)] - impl FieldAccess> { - pub const fn meta() -> &'static dyn Meta { - &ArrayMeta::<$len, $ty> { _phantom: PhantomData } - } - #[inline(always)] - pub const fn size_of_field_at(buf: &[u8]) -> Result { - const N: usize = std::mem::size_of::<$ty>(); - const L: usize = 
std::mem::size_of::<$len>(); - if let Some(len) = buf.first_chunk::() { - let len_value = <$len>::from_be_bytes(*len); - #[allow(unused_comparisons)] - if len_value < 0 { - return Err($crate::protocol::ParseError::InvalidData); - } - let mut byte_len = len_value as usize; - byte_len = match byte_len.checked_mul(N) { - Some(l) => l, - None => return Err($crate::protocol::ParseError::TooShort), - }; - byte_len = match byte_len.checked_add(L) { - Some(l) => l, - None => return Err($crate::protocol::ParseError::TooShort), - }; - if buf.len() < byte_len { - Err($crate::protocol::ParseError::TooShort) - } else { - Ok(byte_len) - } - } else { - Err($crate::protocol::ParseError::TooShort) - } - } - #[inline(always)] - pub const fn extract(mut buf: &[u8]) -> Result, $crate::protocol::ParseError> { - const N: usize = std::mem::size_of::<$ty>(); - const L: usize = std::mem::size_of::<$len>(); - if let Some((len, array)) = buf.split_first_chunk::() { - let len_value = <$len>::from_be_bytes(*len); - #[allow(unused_comparisons)] - if len_value < 0 { - return Err($crate::protocol::ParseError::InvalidData); - } - let mut byte_len = len_value as usize; - byte_len = match byte_len.checked_mul(N) { - Some(l) => l, - None => return Err($crate::protocol::ParseError::TooShort), - }; - byte_len = match byte_len.checked_add(L) { - Some(l) => l, - None => return Err($crate::protocol::ParseError::TooShort), - }; - if buf.len() < byte_len { - Err($crate::protocol::ParseError::TooShort) - } else { - Ok(Array::new(array, <$len>::from_be_bytes(*len) as u32)) - } - } else { - Err($crate::protocol::ParseError::TooShort) - } - } - #[inline(always)] - pub const fn measure(buffer: &[$ty]) -> usize { - buffer.len() * std::mem::size_of::<$ty>() + std::mem::size_of::<$len>() - } - #[inline(always)] - pub fn copy_to_buf(mut buf: &mut BufWriter, value: &[$ty]) { - let size: usize = std::mem::size_of::<$ty>() * value.len() + std::mem::size_of::<$len>(); - if !buf.test(size) { - return; - } - 
buf.write(&<$len>::to_be_bytes(value.len() as _)); - for n in value { - buf.write(&<$ty>::to_be_bytes(*n)); - } - } - } - )* - } -} -basic_types!(u8 i16 i32 u32 u64); diff --git a/rust/pgrust/src/protocol/definition.rs b/rust/pgrust/src/protocol/definition.rs deleted file mode 100644 index bef36d5aa65..00000000000 --- a/rust/pgrust/src/protocol/definition.rs +++ /dev/null @@ -1,740 +0,0 @@ -use super::gen::protocol; -use super::message_group::message_group; -use crate::protocol::meta::*; - -message_group!( - /// The `Backend` message group contains messages sent from the backend to the frontend. - Backend: Message = [ - AuthenticationOk, - AuthenticationKerberosV5, - AuthenticationCleartextPassword, - AuthenticationMD5Password, - AuthenticationGSS, - AuthenticationGSSContinue, - AuthenticationSSPI, - AuthenticationSASL, - AuthenticationSASLContinue, - AuthenticationSASLFinal, - BackendKeyData, - BindComplete, - CloseComplete, - CommandComplete, - CopyData, - CopyDone, - CopyInResponse, - CopyOutResponse, - CopyBothResponse, - DataRow, - EmptyQueryResponse, - ErrorResponse, - FunctionCallResponse, - NegotiateProtocolVersion, - NoData, - NoticeResponse, - NotificationResponse, - ParameterDescription, - ParameterStatus, - ParseComplete, - PortalSuspended, - ReadyForQuery, - RowDescription - ] -); - -message_group!( - /// The `Frontend` message group contains messages sent from the frontend to the backend. - Frontend: Message = [ - Bind, - Close, - CopyData, - CopyDone, - CopyFail, - Describe, - Execute, - Flush, - FunctionCall, - GSSResponse, - Parse, - PasswordMessage, - Query, - SASLInitialResponse, - SASLResponse, - Sync, - Terminate - ] -); - -message_group!( - /// The `Initial` message group contains messages that are sent before the - /// normal message flow. - Initial: InitialMessage = [ - CancelRequest, - GSSENCRequest, - SSLRequest, - StartupMessage - ] -); - -protocol!( - -/// A generic base for all Postgres mtype/mlen-style messages. 
-struct Message { - /// Identifies the message. - mtype: u8, - /// Length of message contents in bytes, including self. - mlen: len, - /// Message contents. - data: Rest, -} - -/// A generic base for all initial Postgres messages. -struct InitialMessage { - /// Length of message contents in bytes, including self. - mlen: len, - /// The identifier for this initial message. - protocol_version: i32, - /// Message contents. - data: Rest -} - -/// The `AuthenticationMessage` struct is a base for all Postgres authentication messages. -struct AuthenticationMessage: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len, - /// Specifies that the authentication was successful. - status: i32, -} - -/// The `AuthenticationOk` struct represents a message indicating successful authentication. -struct AuthenticationOk: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// Specifies that the authentication was successful. - status: i32 = 0, -} - -/// The `AuthenticationKerberosV5` struct represents a message indicating that Kerberos V5 authentication is required. -struct AuthenticationKerberosV5: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// Specifies that Kerberos V5 authentication is required. - status: i32 = 2, -} - -/// The `AuthenticationCleartextPassword` struct represents a message indicating that a cleartext password is required for authentication. -struct AuthenticationCleartextPassword: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// Specifies that a clear-text password is required. 
- status: i32 = 3, -} - -/// The `AuthenticationMD5Password` struct represents a message indicating that an MD5-encrypted password is required for authentication. -struct AuthenticationMD5Password: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 12, - /// Specifies that an MD5-encrypted password is required. - status: i32 = 5, - /// The salt to use when encrypting the password. - salt: [u8; 4], -} - -/// The `AuthenticationSCMCredential` struct represents a message indicating that an SCM credential is required for authentication. -struct AuthenticationSCMCredential: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 6, - /// Any data byte, which is ignored. - byte: u8 = 0, -} - -/// The `AuthenticationGSS` struct represents a message indicating that GSSAPI authentication is required. -struct AuthenticationGSS: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// Specifies that GSSAPI authentication is required. - status: i32 = 7, -} - -/// The `AuthenticationGSSContinue` struct represents a message indicating the continuation of GSSAPI authentication. -struct AuthenticationGSSContinue: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len, - /// Specifies that this message contains GSSAPI or SSPI data. - status: i32 = 8, - /// GSSAPI or SSPI authentication data. - data: Rest, -} - -/// The `AuthenticationSSPI` struct represents a message indicating that SSPI authentication is required. -struct AuthenticationSSPI: Message { - /// Identifies the message as an authentication request. 
- mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// Specifies that SSPI authentication is required. - status: i32 = 9, -} - -/// The `AuthenticationSASL` struct represents a message indicating that SASL authentication is required. -struct AuthenticationSASL: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len, - /// Specifies that SASL authentication is required. - status: i32 = 10, - /// List of SASL authentication mechanisms, terminated by a zero byte. - mechanisms: ZTArray, -} - -/// The `AuthenticationSASLContinue` struct represents a message containing a SASL challenge during the authentication process. -struct AuthenticationSASLContinue: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len, - /// Specifies that this message contains a SASL challenge. - status: i32 = 11, - /// SASL data, specific to the SASL mechanism being used. - data: Rest, -} - -/// The `AuthenticationSASLFinal` struct represents a message indicating the completion of SASL authentication. -struct AuthenticationSASLFinal: Message { - /// Identifies the message as an authentication request. - mtype: u8 = 'R', - /// Length of message contents in bytes, including self. - mlen: len, - /// Specifies that SASL authentication has completed. - status: i32 = 12, - /// SASL outcome "additional data", specific to the SASL mechanism being used. - data: Rest, -} - -/// The `BackendKeyData` struct represents a message containing the process ID and secret key for this backend. -struct BackendKeyData: Message { - /// Identifies the message as cancellation key data. - mtype: u8 = 'K', - /// Length of message contents in bytes, including self. - mlen: len = 12, - /// The process ID of this backend. - pid: i32, - /// The secret key of this backend. 
- key: i32, -} - -/// The `Bind` struct represents a message to bind a named portal to a prepared statement. -struct Bind: Message { - /// Identifies the message as a Bind command. - mtype: u8 = 'B', - /// Length of message contents in bytes, including self. - mlen: len, - /// The name of the destination portal. - portal: ZTString, - /// The name of the source prepared statement. - statement: ZTString, - /// The parameter format codes. - format_codes: Array, - /// Array of parameter values and their lengths. - values: Array, - /// The result-column format codes. - result_format_codes: Array, -} - -/// The `BindComplete` struct represents a message indicating that a Bind operation was successful. -struct BindComplete: Message { - /// Identifies the message as a Bind-complete indicator. - mtype: u8 = '2', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `CancelRequest` struct represents a message to request the cancellation of a query. -struct CancelRequest: InitialMessage { - /// Length of message contents in bytes, including self. - mlen: len = 16, - /// The cancel request code. - code: i32 = 80877102, - /// The process ID of the target backend. - pid: i32, - /// The secret key for the target backend. - key: i32, -} - -/// The `Close` struct represents a message to close a prepared statement or portal. -struct Close: Message { - /// Identifies the message as a Close command. - mtype: u8 = 'C', - /// Length of message contents in bytes, including self. - mlen: len, - /// 'xS' to close a prepared statement; 'P' to close a portal. - ctype: u8, - /// The name of the prepared statement or portal to close. - name: ZTString, -} - -/// The `CloseComplete` struct represents a message indicating that a Close operation was successful. -struct CloseComplete: Message { - /// Identifies the message as a Close-complete indicator. - mtype: u8 = '3', - /// Length of message contents in bytes, including self. 
- mlen: len = 4, -} - -/// The `CommandComplete` struct represents a message indicating the successful completion of a command. -struct CommandComplete: Message { - /// Identifies the message as a command-completed response. - mtype: u8 = 'C', - /// Length of message contents in bytes, including self. - mlen: len, - /// The command tag. - tag: ZTString, -} - -/// The `CopyData` struct represents a message containing data for a copy operation. -struct CopyData: Message { - /// Identifies the message as COPY data. - mtype: u8 = 'd', - /// Length of message contents in bytes, including self. - mlen: len, - /// Data that forms part of a COPY data stream. - data: Rest, -} - -/// The `CopyDone` struct represents a message indicating that a copy operation is complete. -struct CopyDone: Message { - /// Identifies the message as a COPY-complete indicator. - mtype: u8 = 'c', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `CopyFail` struct represents a message indicating that a copy operation has failed. -struct CopyFail: Message { - /// Identifies the message as a COPY-failure indicator. - mtype: u8 = 'f', - /// Length of message contents in bytes, including self. - mlen: len, - /// An error message to report as the cause of failure. - error_msg: ZTString, -} - -/// The `CopyInResponse` struct represents a message indicating that the server is ready to receive data for a copy-in operation. -struct CopyInResponse: Message { - /// Identifies the message as a Start Copy In response. - mtype: u8 = 'G', - /// Length of message contents in bytes, including self. - mlen: len, - /// 0 for textual, 1 for binary. - format: u8, - /// The format codes for each column. - format_codes: Array, -} - -/// The `CopyOutResponse` struct represents a message indicating that the server is ready to send data for a copy-out operation. -struct CopyOutResponse: Message { - /// Identifies the message as a Start Copy Out response. 
- mtype: u8 = 'H', - /// Length of message contents in bytes, including self. - mlen: len, - /// 0 for textual, 1 for binary. - format: u8, - /// The format codes for each column. - format_codes: Array, -} - -/// The `CopyBothResponse` is used only for Streaming Replication. -struct CopyBothResponse: Message { - /// Identifies the message as a Start Copy Both response. - mtype: u8 = 'W', - /// Length of message contents in bytes, including self. - mlen: len, - /// 0 for textual, 1 for binary. - format: u8, - /// The format codes for each column. - format_codes: Array, -} - -/// The `DataRow` struct represents a message containing a row of data. -struct DataRow: Message { - /// Identifies the message as a data row. - mtype: u8 = 'D', - /// Length of message contents in bytes, including self. - mlen: len, - /// Array of column values and their lengths. - values: Array, -} - -/// The `Describe` struct represents a message to describe a prepared statement or portal. -struct Describe: Message { - /// Identifies the message as a Describe command. - mtype: u8 = 'D', - /// Length of message contents in bytes, including self. - mlen: len, - /// 'S' to describe a prepared statement; 'P' to describe a portal. - dtype: u8, - /// The name of the prepared statement or portal. - name: ZTString, -} - -/// The `EmptyQueryResponse` struct represents a message indicating that an empty query string was recognized. -struct EmptyQueryResponse: Message { - /// Identifies the message as a response to an empty query String. - mtype: u8 = 'I', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `ErrorResponse` struct represents a message indicating that an error has occurred. -struct ErrorResponse: Message { - /// Identifies the message as an error. - mtype: u8 = 'E', - /// Length of message contents in bytes, including self. - mlen: len, - /// Array of error fields and their values. 
- fields: ZTArray, -} - -/// The `ErrorField` struct represents a single error message within an `ErrorResponse`. -struct ErrorField { - /// A code identifying the field type. - etype: u8, - /// The field value. - value: ZTString, -} - -/// The `Execute` struct represents a message to execute a prepared statement or portal. -struct Execute: Message { - /// Identifies the message as an Execute command. - mtype: u8 = 'E', - /// Length of message contents in bytes, including self. - mlen: len, - /// The name of the portal to execute. - portal: ZTString, - /// Maximum number of rows to return. - max_rows: i32, -} - -/// The `Flush` struct represents a message to flush the backend's output buffer. -struct Flush: Message { - /// Identifies the message as a Flush command. - mtype: u8 = 'H', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `FunctionCall` struct represents a message to call a function. -struct FunctionCall: Message { - /// Identifies the message as a function call. - mtype: u8 = 'F', - /// Length of message contents in bytes, including self. - mlen: len, - /// OID of the function to execute. - function_id: i32, - /// The parameter format codes. - format_codes: Array, - /// Array of args and their lengths. - args: Array, - /// The format code for the result. - result_format_code: i16, -} - -/// The `FunctionCallResponse` struct represents a message containing the result of a function call. -struct FunctionCallResponse: Message { - /// Identifies the message as a function-call response. - mtype: u8 = 'V', - /// Length of message contents in bytes, including self. - mlen: len, - /// The function result value. - result: Encoded, -} - -/// The `GSSENCRequest` struct represents a message requesting GSSAPI encryption. -struct GSSENCRequest: InitialMessage { - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// The GSSAPI Encryption request code. 
- gssenc_request_code: i32 = 80877104, -} - -/// The `GSSResponse` struct represents a message containing a GSSAPI or SSPI response. -struct GSSResponse: Message { - /// Identifies the message as a GSSAPI or SSPI response. - mtype: u8 = 'p', - /// Length of message contents in bytes, including self. - mlen: len, - /// GSSAPI or SSPI authentication data. - data: Rest, -} - -/// The `NegotiateProtocolVersion` struct represents a message requesting protocol version negotiation. -struct NegotiateProtocolVersion: Message { - /// Identifies the message as a protocol version negotiation request. - mtype: u8 = 'v', - /// Length of message contents in bytes, including self. - mlen: len, - /// Newest minor protocol version supported by the server. - minor_version: i32, - /// List of protocol options not recognized. - options: Array, -} - -/// The `NoData` struct represents a message indicating that there is no data to return. -struct NoData: Message { - /// Identifies the message as a No Data indicator. - mtype: u8 = 'n', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `NoticeResponse` struct represents a message containing a notice. -struct NoticeResponse: Message { - /// Identifies the message as a notice. - mtype: u8 = 'N', - /// Length of message contents in bytes, including self. - mlen: len, - /// Array of notice fields and their values. - fields: ZTArray, -} - -/// The `NoticeField` struct represents a single error message within an `NoticeResponse`. -struct NoticeField: Message { - /// A code identifying the field type. - ntype: u8, - /// The field value. - value: ZTString, -} - -/// The `NotificationResponse` struct represents a message containing a notification from the backend. -struct NotificationResponse: Message { - /// Identifies the message as a notification. - mtype: u8 = 'A', - /// Length of message contents in bytes, including self. - mlen: len, - /// The process ID of the notifying backend. 
- pid: i32, - /// The name of the notification channel. - channel: ZTString, - /// The notification payload. - payload: ZTString, -} - -/// The `ParameterDescription` struct represents a message describing the parameters needed by a prepared statement. -struct ParameterDescription: Message { - /// Identifies the message as a parameter description. - mtype: u8 = 't', - /// Length of message contents in bytes, including self. - mlen: len, - /// OIDs of the parameter data types. - param_types: Array, -} - -/// The `ParameterStatus` struct represents a message containing the current status of a parameter. -struct ParameterStatus: Message { - /// Identifies the message as a runtime parameter status report. - mtype: u8 = 'S', - /// Length of message contents in bytes, including self. - mlen: len, - /// The name of the parameter. - name: ZTString, - /// The current value of the parameter. - value: ZTString, -} - -/// The `Parse` struct represents a message to parse a query string. -struct Parse: Message { - /// Identifies the message as a Parse command. - mtype: u8 = 'P', - /// Length of message contents in bytes, including self. - mlen: len, - /// The name of the destination prepared statement. - statement: ZTString, - /// The query String to be parsed. - query: ZTString, - /// OIDs of the parameter data types. - param_types: Array, -} - -/// The `ParseComplete` struct represents a message indicating that a Parse operation was successful. -struct ParseComplete: Message { - /// Identifies the message as a Parse-complete indicator. - mtype: u8 = '1', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `PasswordMessage` struct represents a message containing a password. -struct PasswordMessage: Message { - /// Identifies the message as a password response. - mtype: u8 = 'p', - /// Length of message contents in bytes, including self. - mlen: len, - /// The password (encrypted or plaintext, depending on context). 
- password: ZTString, -} - -/// The `PortalSuspended` struct represents a message indicating that a portal has been suspended. -struct PortalSuspended: Message { - /// Identifies the message as a portal-suspended indicator. - mtype: u8 = 's', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `Query` struct represents a message to execute a simple query. -struct Query: Message { - /// Identifies the message as a simple query command. - mtype: u8 = 'Q', - /// Length of message contents in bytes, including self. - mlen: len, - /// The query String to be executed. - query: ZTString, -} - -/// The `ReadyForQuery` struct represents a message indicating that the backend is ready for a new query. -struct ReadyForQuery: Message { - /// Identifies the message as a ready-for-query indicator. - mtype: u8 = 'Z', - /// Length of message contents in bytes, including self. - mlen: len = 5, - /// Current transaction status indicator. - status: u8, -} - -/// The `RowDescription` struct represents a message describing the rows that will be returned by a query. -struct RowDescription: Message { - /// Identifies the message as a row description. - mtype: u8 = 'T', - /// Length of message contents in bytes, including self. - mlen: len, - /// Array of field descriptions. - fields: Array, -} - -/// The `RowField` struct represents a row within the `RowDescription` message. 
-struct RowField { - /// The field name - name: ZTString, - /// The table ID (OID) of the table the column is from, or 0 if not a column reference - table_oid: i32, - /// The attribute number of the column, or 0 if not a column reference - column_attr_number: i16, - /// The object ID of the field's data type - data_type_oid: i32, - /// The data type size (negative if variable size) - data_type_size: i16, - /// The type modifier - type_modifier: i32, - /// The format code being used for the field (0 for text, 1 for binary) - format_code: i16, -} - -/// The `SASLInitialResponse` struct represents a message containing a SASL initial response. -struct SASLInitialResponse: Message { - /// Identifies the message as a SASL initial response. - mtype: u8 = 'p', - /// Length of message contents in bytes, including self. - mlen: len, - /// Name of the SASL authentication mechanism. - mechanism: ZTString, - /// SASL initial response data. - response: Array, -} - -/// The `SASLResponse` struct represents a message containing a SASL response. -struct SASLResponse: Message { - /// Identifies the message as a SASL response. - mtype: u8 = 'p', - /// Length of message contents in bytes, including self. - mlen: len, - /// SASL response data. - response: Rest, -} - -/// The `SSLRequest` struct represents a message requesting SSL encryption. -struct SSLRequest: InitialMessage { - /// Length of message contents in bytes, including self. - mlen: len = 8, - /// The SSL request code. - code: i32 = 80877103, -} - -struct SSLResponse { - /// Specifies if SSL was accepted or rejected. - code: u8, -} - -/// The `StartupMessage` struct represents a message to initiate a connection. -struct StartupMessage: InitialMessage { - /// Length of message contents in bytes, including self. - mlen: len, - /// The protocol version number. - protocol: i32 = 196608, - /// List of parameter name-value pairs, terminated by a zero byte. 
- params: ZTArray, -} - -/// The `StartupMessage` struct represents a name/value pair within the `StartupMessage` message. -struct StartupNameValue { - /// The parameter name. - name: ZTString, - /// The parameter value. - value: ZTString, -} - -/// The `Sync` struct represents a message to synchronize the frontend and backend. -struct Sync: Message { - /// Identifies the message as a Sync command. - mtype: u8 = 'S', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} - -/// The `Terminate` struct represents a message to terminate a connection. -struct Terminate: Message { - /// Identifies the message as a Terminate command. - mtype: u8 = 'X', - /// Length of message contents in bytes, including self. - mlen: len = 4, -} -); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_all() { - let message = meta::Message::default(); - let initial_message = meta::InitialMessage::default(); - - for meta in meta::ALL { - eprintln!("{meta:#?}"); - if **meta != message && **meta != initial_message { - if meta.field("mtype").is_some() && meta.field("mlen").is_some() { - // If a message has mtype and mlen, it should subclass Message - assert_eq!(*meta.parent().unwrap(), message); - } else if meta.field("mlen").is_some() { - // If a message has mlen only, it should subclass InitialMessage - assert_eq!(*meta.parent().unwrap(), initial_message); - } - } - } - } -} diff --git a/rust/pgrust/src/protocol/edgedb.rs b/rust/pgrust/src/protocol/edgedb.rs index 98a25c2155b..da32a8d6367 100644 --- a/rust/pgrust/src/protocol/edgedb.rs +++ b/rust/pgrust/src/protocol/edgedb.rs @@ -1,5 +1,5 @@ -use super::gen::protocol; -use crate::protocol::message_group::message_group; +use db_proto::{message_group, protocol}; + message_group!( EdgeDBBackend: Message = [ AuthenticationOk, diff --git a/rust/pgrust/src/protocol/mod.rs b/rust/pgrust/src/protocol/mod.rs index 568272eabc6..c09f9c97449 100644 --- a/rust/pgrust/src/protocol/mod.rs +++ 
b/rust/pgrust/src/protocol/mod.rs @@ -1,168 +1,10 @@ -mod arrays; -mod buffer; -mod datatypes; pub mod edgedb; -mod gen; -mod message_group; pub mod postgres; -mod writer; - -/// Metatypes for the protocol and related arrays/strings. -pub mod meta { - pub use super::arrays::meta::*; - pub use super::datatypes::meta::*; -} - -#[allow(unused)] -pub use arrays::{Array, ArrayIter, ZTArray, ZTArrayIter}; -pub use buffer::StructBuffer; -#[allow(unused)] -pub use datatypes::{Encoded, LString, Rest, ZTString}; -pub use message_group::match_message; -pub use writer::BufWriter; - -#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq, Eq)] -pub enum ParseError { - #[error("Buffer is too short")] - TooShort, - #[error("Invalid data")] - InvalidData, -} - -/// Implemented for all structs. -pub trait StructMeta { - type Struct<'a>: std::fmt::Debug; - fn new(buf: &[u8]) -> Result, ParseError>; - fn to_vec(s: &Self::Struct<'_>) -> Vec; -} - -/// Implemented for all generated structs that have a [`meta::Length`] field at a fixed offset. -pub trait StructLength: StructMeta { - fn length_field_of(of: &Self::Struct<'_>) -> usize; - fn length_field_offset() -> usize; - fn length_of_buf(buf: &[u8]) -> Option { - if buf.len() < Self::length_field_offset() + std::mem::size_of::() { - None - } else { - let len = FieldAccess::::extract( - &buf[Self::length_field_offset() - ..Self::length_field_offset() + std::mem::size_of::()], - ) - .ok()?; - Some(Self::length_field_offset() + len) - } - } -} - -/// For a given metaclass, returns the inflated type, a measurement type and a -/// builder type. -pub trait Enliven { - type WithLifetime<'a>; - type ForMeasure<'a>: 'a; - type ForBuilder<'a>: 'a; -} - -pub trait FixedSize: Enliven { - const SIZE: usize; - /// Extract this type from the given buffer, assuming that enough bytes are available. 
- fn extract_infallible(buf: &[u8]) -> ::WithLifetime<'_>; -} - -#[derive(Debug, Eq, PartialEq)] -pub enum MetaRelation { - Parent, - Length, - Item, - Field(&'static str), -} - -pub trait Meta { - fn name(&self) -> &'static str { - std::any::type_name::() - } - fn relations(&self) -> &'static [(MetaRelation, &'static dyn Meta)] { - &[] - } - fn field(&self, name: &'static str) -> Option<&'static dyn Meta> { - for (relation, meta) in self.relations() { - if relation == &MetaRelation::Field(name) { - return Some(*meta); - } - } - None - } - fn parent(&self) -> Option<&'static dyn Meta> { - for (relation, meta) in self.relations() { - if relation == &MetaRelation::Parent { - return Some(*meta); - } - } - None - } -} - -impl PartialEq for dyn Meta { - fn eq(&self, other: &T) -> bool { - other.name() == self.name() - } -} - -impl std::fmt::Debug for dyn Meta { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut s = f.debug_struct(self.name()); - for (relation, meta) in self.relations() { - if relation == &MetaRelation::Parent { - s.field(&format!("{relation:?}"), &meta.name()); - } else { - s.field(&format!("{relation:?}"), meta); - } - } - s.finish() - } -} - -/// Delegates to a concrete [`FieldAccess`] but as a non-const trait. This is -/// used for performing extraction in iterators. -pub(crate) trait FieldAccessArray: Enliven { - const META: &'static dyn Meta; - fn size_of_field_at(buf: &[u8]) -> Result; - fn extract(buf: &[u8]) -> Result<::WithLifetime<'_>, ParseError>; -} - -/// This struct is specialized for each type we want to extract data from. We -/// have to do it this way to work around Rust's lack of const specialization. -pub(crate) struct FieldAccess { - _phantom_data: std::marker::PhantomData, -} - -/// Delegate to the concrete [`FieldAccess`] for each type we want to extract. -macro_rules! 
field_access { - ($ty:ty) => { - impl $crate::protocol::FieldAccessArray for $ty { - const META: &'static dyn $crate::protocol::Meta = - $crate::protocol::FieldAccess::<$ty>::meta(); - #[inline(always)] - fn size_of_field_at(buf: &[u8]) -> Result { - $crate::protocol::FieldAccess::<$ty>::size_of_field_at(buf) - } - #[inline(always)] - fn extract( - buf: &[u8], - ) -> Result< - ::WithLifetime<'_>, - $crate::protocol::ParseError, - > { - $crate::protocol::FieldAccess::<$ty>::extract(buf) - } - } - }; -} -pub(crate) use field_access; #[cfg(test)] mod tests { use super::*; - use buffer::StructBuffer; + use db_proto::{match_message, Encoded, StructBuffer, StructMeta}; use postgres::{builder, data::*, measure, meta}; use rand::Rng; /// We want to ensure that no malformed messages will cause unexpected diff --git a/rust/pgrust/src/protocol/postgres.rs b/rust/pgrust/src/protocol/postgres.rs index df313ef106a..bccc5eb126d 100644 --- a/rust/pgrust/src/protocol/postgres.rs +++ b/rust/pgrust/src/protocol/postgres.rs @@ -1,5 +1,4 @@ -use super::gen::protocol; -use super::message_group::message_group; +use db_proto::{message_group, protocol}; message_group!( /// The `Backend` message group contains messages sent from the backend to the frontend. 
diff --git a/rust/pgrust/src/python.rs b/rust/pgrust/src/python.rs index 0914e181367..6cb61da2bc1 100644 --- a/rust/pgrust/src/python.rs +++ b/rust/pgrust/src/python.rs @@ -11,11 +11,9 @@ use crate::{ }, ConnectionSslRequirement, }, - protocol::{ - postgres::{data::SSLResponse, meta, FrontendBuilder, InitialBuilder}, - StructBuffer, - }, + protocol::postgres::{data::SSLResponse, meta, FrontendBuilder, InitialBuilder}, }; +use db_proto::StructBuffer; use pyo3::{ buffer::PyBuffer, exceptions::{PyException, PyRuntimeError}, @@ -50,12 +48,6 @@ impl From for PyErr { } } -impl From for PyErr { - fn from(err: crate::protocol::ParseError) -> PyErr { - PyRuntimeError::new_err(err.to_string()) - } -} - impl EnvVar for (String, Bound<'_, PyAny>) { fn read(&self, name: &'static str) -> Option> { // os.environ[name], or the default user if not @@ -359,7 +351,8 @@ impl PyConnectionState { if self.inner.read_ssl_response() { // SSL responses are always one character let response = [buffer.as_slice(py).unwrap().first().unwrap().get()]; - let response = SSLResponse::new(&response)?; + let response = + SSLResponse::new(&response).map_err(|e| PyException::new_err(e.to_string()))?; self.inner .drive(ConnectionDrive::SslResponse(response), &mut self.update)?; } else { diff --git a/rust/pgrust/tests/query_real_postgres.rs b/rust/pgrust/tests/query_real_postgres.rs index f8dde14716d..ce6a9bfa41f 100644 --- a/rust/pgrust/tests/query_real_postgres.rs +++ b/rust/pgrust/tests/query_real_postgres.rs @@ -4,13 +4,13 @@ use std::num::NonZero; use std::rc::Rc; // Constants +use db_proto::match_message; use gel_auth::AuthType; use pgrust::connection::tokio::TokioStream; use pgrust::connection::{ Client, Credentials, FlowAccumulator, MaxRows, Oid, Param, PipelineBuilder, Portal, ResolvedTarget, Statement, }; -use pgrust::protocol::match_message; use pgrust::protocol::postgres::data::*; use tokio::task::LocalSet;