diff --git a/Cargo.toml b/Cargo.toml index 4dcca311af82..501874ecf307 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ members = [ resolver = "2" [workspace.dependencies] -pyo3 = { version = "0.23.1", features = ["extension-module", "serde", "macros"] } +pyo3 = { version = "0.23.1", features = ["serde", "macros"] } tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros", "time", "sync", "net", "io-util"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter"] } diff --git a/rust/auth/src/lib.rs b/rust/auth/src/lib.rs index b4b51a700d30..4310b15e060a 100644 --- a/rust/auth/src/lib.rs +++ b/rust/auth/src/lib.rs @@ -34,7 +34,7 @@ pub enum AuthType { ScramSha256, } -#[derive(Debug, Clone)] +#[derive(derive_more::Debug, Clone)] pub enum CredentialData { /// A credential that always succeeds, regardless of input password. Due to /// the design of SCRAM-SHA-256, this cannot be used with that auth type. @@ -42,10 +42,13 @@ pub enum CredentialData { /// A credential that always fails, regardless of the input password. Deny, /// A plain-text password. + #[debug("Plain(...)")] Plain(String), /// A stored MD5 hash + salt. + #[debug("Md5(...)")] Md5(md5::StoredHash), /// A stored SCRAM-SHA-256 key. + #[debug("Scram(...)")] Scram(scram::StoredKey), } diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml index 131e28e429b6..71e8255463d4 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/frontend/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2018" [features] -python_extension = ["pyo3/extension-module", "pyo3/serde"] +python_extension = ["pyo3/serde"] [dependencies] pyo3.workspace = true @@ -38,5 +38,6 @@ tracing-subscriber = "0" [dev-dependencies] rstest = "0.22.0" test-log = { version = "0", features = ["trace"] } +pyo3 = { workspace = true } [lib] diff --git a/rust/frontend/examples/smoketest.rs b/rust/frontend/examples/smoketest.rs new file mode 100644 index 000000000000..914a470dbbad --- /dev/null +++ b/rust/frontend/examples/smoketest.rs @@ -0,0 +1,247 @@ +use std::{cell::RefCell, collections::HashMap, future::Future, rc::Rc}; + +use gel_auth::CredentialData; +use openssl::ssl::{Ssl, SslContext, SslMethod}; +use pgrust::{connection::{Client, Credentials}, protocol::{edgedb::data::{CommandComplete, ParameterStatus, StateDataDescription}, StructBuffer}}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, net::TcpSocket, task::{self, LocalSet} +}; +use std::pin::{Pin, pin}; + +#[derive(Debug, Clone)] +struct TestSetup { + addr: std::net::SocketAddr, + username: String, + password: String, + database: String, +} + +trait SmokeTest { + fn name(&self) -> String; + async fn run(&self, setup: &TestSetup) -> Result<(), Box>; +} + +struct PostgresSelect { + query: String, + expected: String, +} + +impl SmokeTest for PostgresSelect { + fn name(&self) -> String { + format!("PostgresSelect [{}]", self.query) + } + + async fn run(&self, setup: &TestSetup) -> Result<(), Box> { + use pgrust::protocol::postgres::data::{DataRow, ErrorResponse, RowDescription}; + let mut ssl = SslContext::builder(SslMethod::tls_client())?.build(); + let mut ssl = Ssl::new(&ssl)?; + ssl.set_connect_state(); + + let socket = TcpSocket::new_v4()?.connect(setup.addr).await?; + + let credentials = Credentials { + username: setup.username.clone(), + password: setup.password.clone(), + database: setup.database.clone(), + server_settings: HashMap::new(), + }; + let (client, task) = Client::new(credentials, socket, ssl); + tokio::task::spawn_local(task); + client.ready().await?; + + let mut out = Rc::new(RefCell::new(String::new())); + let out_clone = out.clone(); + client + .query( + &self.query, + ( + move |rows: RowDescription<'_>| { + let cols = rows.fields().into_iter().map(|field| field.name().to_string_lossy().to_string()).collect::>(); + out.borrow_mut().push_str(&format!("{}\n", cols.join(","))); + let out = out.clone(); + move |row: Result, ErrorResponse<'_>>| { + let Ok(row) = row else { + return; + }; + let values: Vec<_> = row.values().into_iter().map(|v| v.to_string_lossy().to_string()).collect(); + out.borrow_mut().push_str(&format!("{}\n", values.join(","))); + } + }, + |_: ErrorResponse<'_>| {}, + ), + ) + .await?; + + let out = out_clone.borrow().clone(); + if out == self.expected { + Ok(()) + } else { + Err(format!( + "Expected `{}` but got `{}`", + self.expected, + out + ) + .into()) + } + } +} + +struct EdgeQLSelect { + query: String, + expected: String, +} + +impl SmokeTest for EdgeQLSelect { + fn name(&self) -> String { + format!("EdgeQLSelect [{}]", self.query) + } + + async fn run(&self, setup: &TestSetup) -> Result<(), Box> { + use pgrust::protocol::edgedb::{data::{Message, ClientHandshake, Data, ServerHandshake}, builder, meta}; + + let socket = TcpSocket::new_v4()?.connect(setup.addr).await?; + let mut ssl = SslContext::builder(SslMethod::tls_client())?; + ssl.set_alpn_protos(b"\x0dedgedb-binary")?; + let ssl = ssl.build(); + let mut ssl = Ssl::new(&ssl)?; + ssl.set_connect_state(); + + let mut stream = tokio_openssl::SslStream::new(ssl, socket)?; + Pin::new(&mut stream).do_handshake().await?; + + let handshake = builder::ClientHandshake { + major_ver: 2, + minor_ver: 0, + params: &[ + builder::ConnectionParam { + name: "user", + value: &setup.username, + }, + builder::ConnectionParam { + name: "database", + value: &setup.database, + }, + ], + extensions: &[], + }; + stream.write_all(&handshake.to_vec()).await?; + + let execute = builder::Execute { + command_text: &self.query, + output_format: b'j', + expected_cardinality: b'o', // AT_MOST_ONE + ..Default::default() + }; + stream.write_all(&execute.to_vec()).await?; + + let mut buf = StructBuffer::::default(); + + let mut done = false; + while !done { + let mut bytes = vec![0; 1024]; + let n = stream.read(&mut bytes).await?; + if n == 0 { + break; + } + buf.push(&bytes[..n], |msg| { + match msg { + Ok(msg) => { + if let Some(msg) = StateDataDescription::try_new(&msg) { + eprintln!("{:?}", String::from_utf8_lossy(msg.typedesc().as_ref())); + } else if let Some(msg) = ParameterStatus::try_new(&msg) { + eprintln!("{:?} {:?}", String::from_utf8_lossy(msg.name().as_ref()), String::from_utf8_lossy(msg.value().as_ref())); + } else if let Some(data) = Data::try_new(&msg) { + for data in data.data() { + eprintln!("{:?}", data.data()); + } + } else if let Some(_) = CommandComplete::try_new(&msg) { + done = true; + return; + } else { + eprintln!("{} {:?}", msg.mtype() as char, msg); + } + } + Err(e) => { + eprintln!("Error: {}", e); + } + } + }); + } + + Ok(()) + } +} + +#[tokio::main] +pub async fn main() { + tracing_subscriber::fmt::init(); + + let args: Vec = std::env::args().collect(); + if args.len() != 5 { + println!( + "Usage: {} ", + args[0] + ); + return; + } + + let addr = &args[1]; + let username = &args[2]; + let password = &args[3]; + let database = &args[4]; + + let addr = match addr.parse::() { + Ok(addr) => addr, + Err(e) => { + eprintln!("Invalid address format: {}", e); + return; + } + }; + + let setup = TestSetup { + addr, + username: username.to_string(), + password: password.to_string(), + database: database.to_string(), + }; + + LocalSet::new() + .run_until(async { + let mut tests: Vec + 'static>>> = vec![]; + + fn test(setup: &TestSetup, test: impl SmokeTest + 'static) -> Pin + 'static>> { + let setup = setup.clone(); + Box::pin(async move { + let name = test.name(); + let res = test.run(&setup).await; + match res { + Ok(_) => println!("✅ {name} passed"), + Err(e) => println!("❌ {name} failed: {}", e), + }; + }) + } + + tests.push(test(&setup, PostgresSelect { + query: "SELECT".to_string(), + expected: "\n\n".to_string(), + })); + tests.push(test(&setup, PostgresSelect { + query: "SELECT 1 as x".to_string(), + expected: "x\n1\n".to_string(), + })); + tests.push(test(&setup, PostgresSelect { + query: "SELECT LIMIT 0".to_string(), + expected: "\n".to_string(), + })); + tests.push(test(&setup, EdgeQLSelect { + query: "select 1".to_string(), + expected: "1\n".to_string(), + })); + + for test in tests { + test.await; + } + + }) + .await; +} diff --git a/rust/frontend/src/listener.rs b/rust/frontend/src/listener.rs index 8fbc2d82e098..e1891794713e 100644 --- a/rust/frontend/src/listener.rs +++ b/rust/frontend/src/listener.rs @@ -369,7 +369,8 @@ async fn handle_stream_edgedb_binary( trace!("UPDATE: {update:?}"); match update { Auth(user, database, branch) => { - identity.set_branch(BranchDB::Branch(database)); + identity.set_branch(branch); + identity.set_database(database); identity.set_user(user); auth_ready.store(true, Ordering::SeqCst); } @@ -592,7 +593,7 @@ async fn handle_stream_postgres_initial( trace!("UPDATE: {update:?}"); match update { Auth(user, database) => { - identity.set_branch(BranchDB::Branch(database)); + identity.set_pg_database(database); identity.set_user(user); auth_ready.store(true, Ordering::SeqCst); } @@ -1013,7 +1014,7 @@ mod tests { target: AuthTarget, ) -> impl Future> { self.log(format!("lookup_auth: {:?}", identity)); - async { Ok(CredentialData::Deny) } + async { Ok(CredentialData::Trust) } } fn accept_stream( @@ -1120,14 +1121,14 @@ mod tests { value: "name", }, StartupNameValue { - name: "username", + name: "user", value: "me", }, ], } .to_vec(); stm.write_all(&msg).await.unwrap(); - assert_eq!(stm.read_u8().await.unwrap(), b'S'); + assert_eq!(stm.read_u8().await.unwrap(), b'R'); // AuthenticationOk Ok(()) }); } diff --git a/rust/frontend/src/python.rs b/rust/frontend/src/python.rs index 8b137891791f..a9f081003b57 100644 --- a/rust/frontend/src/python.rs +++ b/rust/frontend/src/python.rs @@ -1 +1,16 @@ +#[cfg(test)] +mod tests { + use pyo3::{types::PyAnyMethods, Python}; + use super::*; + + #[test] + fn test_python_extension() { + pyo3::prepare_freethreaded_python(); + Python::with_gil(|py| { + let sys = py.import("sys").unwrap(); + let version = sys.getattr("version").unwrap(); + println!("Python version: {}", version); + }); + } +} diff --git a/rust/frontend/src/service.rs b/rust/frontend/src/service.rs index fae3b4fcabf4..43052406d64c 100644 --- a/rust/frontend/src/service.rs +++ b/rust/frontend/src/service.rs @@ -23,10 +23,10 @@ pub enum AuthTarget { pub enum BranchDB { /// Branch only. Branch(String), - /// DB only (legacy). + /// Database name (legacy). DB(String), - /// Branch and DB (advanced). - BranchDB(String, String), + /// Postgres database name. + PGDB(String), } #[derive(thiserror::Error, Debug)] @@ -58,8 +58,28 @@ impl ConnectionIdentityBuilder { self } - pub fn set_branch(&self, branch: BranchDB) -> &Self { - *self.db.lock().unwrap() = Some(branch); + pub fn set_database(&self, database: String) -> &Self { + if !database.is_empty() { + // Only set if currently non-empty + let mut db = self.db.lock().unwrap(); + if db.is_none() { + *db = Some(BranchDB::DB(database)); + } + } + self + } + + pub fn set_branch(&self, branch: String) -> &Self { + if !branch.is_empty() { + *self.db.lock().unwrap() = Some(BranchDB::Branch(branch)); + } + self + } + + pub fn set_pg_database(&self, database: String) -> &Self { + if !database.is_empty() { + *self.db.lock().unwrap() = Some(BranchDB::PGDB(database)); + } self } diff --git a/rust/pgrust/src/connection/mod.rs b/rust/pgrust/src/connection/mod.rs index c0a7e9297593..5c8f011ea129 100644 --- a/rust/pgrust/src/connection/mod.rs +++ b/rust/pgrust/src/connection/mod.rs @@ -91,7 +91,6 @@ pub enum SslError { #[derive(Clone, Default, derive_more::Debug)] pub struct Credentials { pub username: String, - #[debug(skip)] pub password: String, pub database: String, pub server_settings: HashMap, diff --git a/rust/pgrust/src/handshake/mod.rs b/rust/pgrust/src/handshake/mod.rs index a253e3bbde67..bc3689b5c206 100644 --- a/rust/pgrust/src/handshake/mod.rs +++ b/rust/pgrust/src/handshake/mod.rs @@ -139,7 +139,7 @@ mod tests { username: "user".to_string(), password: "password".to_string(), database: "database".to_string(), - ..Default::default() + server_settings: Default::default(), }, ConnectionSslRequirement::Disable, );