Skip to content

Commit

Permalink
wip: smoketest
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Jan 2, 2025
1 parent cde234c commit 19787c6
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ members = [
resolver = "2"

[workspace.dependencies]
pyo3 = { version = "0.23.1", features = ["extension-module", "serde", "macros"] }
pyo3 = { version = "0.23.1", features = ["serde", "macros"] }
tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros", "time", "sync", "net", "io-util"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter"] }
Expand Down
5 changes: 4 additions & 1 deletion rust/auth/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,21 @@ pub enum AuthType {
ScramSha256,
}

#[derive(Debug, Clone)]
#[derive(derive_more::Debug, Clone)]
pub enum CredentialData {
/// A credential that always succeeds, regardless of input password. Due to
/// the design of SCRAM-SHA-256, this cannot be used with that auth type.
Trust,
/// A credential that always fails, regardless of the input password.
Deny,
/// A plain-text password.
#[debug("Plain(...)")]
Plain(String),
/// A stored MD5 hash + salt.
#[debug("Md5(...)")]
Md5(md5::StoredHash),
/// A stored SCRAM-SHA-256 key.
#[debug("Scram(...)")]
Scram(scram::StoredKey),
}

Expand Down
3 changes: 2 additions & 1 deletion rust/frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2018"

[features]
python_extension = ["pyo3/extension-module", "pyo3/serde"]
python_extension = ["pyo3/serde"]

[dependencies]
pyo3.workspace = true
Expand Down Expand Up @@ -38,5 +38,6 @@ tracing-subscriber = "0"
[dev-dependencies]
rstest = "0.22.0"
test-log = { version = "0", features = ["trace"] }
pyo3 = { workspace = true }

[lib]
247 changes: 247 additions & 0 deletions rust/frontend/examples/smoketest.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
use std::{cell::RefCell, collections::HashMap, future::Future, rc::Rc};

use gel_auth::CredentialData;
use openssl::ssl::{Ssl, SslContext, SslMethod};
use pgrust::{connection::{Client, Credentials}, protocol::{edgedb::data::{CommandComplete, ParameterStatus, StateDataDescription}, StructBuffer}};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, net::TcpSocket, task::{self, LocalSet}
};
use std::pin::{Pin, pin};

#[derive(Debug, Clone)]
struct TestSetup {
addr: std::net::SocketAddr,
username: String,
password: String,
database: String,
}

trait SmokeTest {
fn name(&self) -> String;
async fn run(&self, setup: &TestSetup) -> Result<(), Box<dyn std::error::Error>>;
}

struct PostgresSelect {
query: String,
expected: String,
}

impl SmokeTest for PostgresSelect {
fn name(&self) -> String {
format!("PostgresSelect [{}]", self.query)
}

async fn run(&self, setup: &TestSetup) -> Result<(), Box<dyn std::error::Error>> {
use pgrust::protocol::postgres::data::{DataRow, ErrorResponse, RowDescription};
let mut ssl = SslContext::builder(SslMethod::tls_client())?.build();
let mut ssl = Ssl::new(&ssl)?;
ssl.set_connect_state();

let socket = TcpSocket::new_v4()?.connect(setup.addr).await?;

let credentials = Credentials {
username: setup.username.clone(),
password: setup.password.clone(),
database: setup.database.clone(),
server_settings: HashMap::new(),
};
let (client, task) = Client::new(credentials, socket, ssl);
tokio::task::spawn_local(task);
client.ready().await?;

let mut out = Rc::new(RefCell::new(String::new()));
let out_clone = out.clone();
client
.query(
&self.query,
(
move |rows: RowDescription<'_>| {
let cols = rows.fields().into_iter().map(|field| field.name().to_string_lossy().to_string()).collect::<Vec<_>>();
out.borrow_mut().push_str(&format!("{}\n", cols.join(",")));
let out = out.clone();
move |row: Result<DataRow<'_>, ErrorResponse<'_>>| {
let Ok(row) = row else {
return;
};
let values: Vec<_> = row.values().into_iter().map(|v| v.to_string_lossy().to_string()).collect();
out.borrow_mut().push_str(&format!("{}\n", values.join(",")));
}
},
|_: ErrorResponse<'_>| {},
),
)
.await?;

let out = out_clone.borrow().clone();
if out == self.expected {
Ok(())
} else {
Err(format!(
"Expected `{}` but got `{}`",
self.expected,
out
)
.into())
}
}
}

struct EdgeQLSelect {
query: String,
expected: String,
}

impl SmokeTest for EdgeQLSelect {
fn name(&self) -> String {
format!("EdgeQLSelect [{}]", self.query)
}

async fn run(&self, setup: &TestSetup) -> Result<(), Box<dyn std::error::Error>> {
use pgrust::protocol::edgedb::{data::{Message, ClientHandshake, Data, ServerHandshake}, builder, meta};

let socket = TcpSocket::new_v4()?.connect(setup.addr).await?;
let mut ssl = SslContext::builder(SslMethod::tls_client())?;
ssl.set_alpn_protos(b"\x0dedgedb-binary")?;
let ssl = ssl.build();
let mut ssl = Ssl::new(&ssl)?;
ssl.set_connect_state();

let mut stream = tokio_openssl::SslStream::new(ssl, socket)?;
Pin::new(&mut stream).do_handshake().await?;

let handshake = builder::ClientHandshake {
major_ver: 2,
minor_ver: 0,
params: &[
builder::ConnectionParam {
name: "user",
value: &setup.username,
},
builder::ConnectionParam {
name: "database",
value: &setup.database,
},
],
extensions: &[],
};
stream.write_all(&handshake.to_vec()).await?;

let execute = builder::Execute {
command_text: &self.query,
output_format: b'j',
expected_cardinality: b'o', // AT_MOST_ONE
..Default::default()
};
stream.write_all(&execute.to_vec()).await?;

let mut buf = StructBuffer::<meta::Message>::default();

let mut done = false;
while !done {
let mut bytes = vec![0; 1024];
let n = stream.read(&mut bytes).await?;
if n == 0 {
break;
}
buf.push(&bytes[..n], |msg| {
match msg {
Ok(msg) => {
if let Some(msg) = StateDataDescription::try_new(&msg) {
eprintln!("{:?}", String::from_utf8_lossy(msg.typedesc().as_ref()));
} else if let Some(msg) = ParameterStatus::try_new(&msg) {
eprintln!("{:?} {:?}", String::from_utf8_lossy(msg.name().as_ref()), String::from_utf8_lossy(msg.value().as_ref()));
} else if let Some(data) = Data::try_new(&msg) {
for data in data.data() {
eprintln!("{:?}", data.data());
}
} else if let Some(_) = CommandComplete::try_new(&msg) {
done = true;
return;
} else {
eprintln!("{} {:?}", msg.mtype() as char, msg);
}
}
Err(e) => {
eprintln!("Error: {}", e);
}
}
});
}

Ok(())
}
}

#[tokio::main]
pub async fn main() {
tracing_subscriber::fmt::init();

let args: Vec<String> = std::env::args().collect();
if args.len() != 5 {
println!(
"Usage: {} <addr:port> <username> <password> <database>",
args[0]
);
return;
}

let addr = &args[1];
let username = &args[2];
let password = &args[3];
let database = &args[4];

let addr = match addr.parse::<std::net::SocketAddr>() {
Ok(addr) => addr,
Err(e) => {
eprintln!("Invalid address format: {}", e);
return;
}
};

let setup = TestSetup {
addr,
username: username.to_string(),
password: password.to_string(),
database: database.to_string(),
};

LocalSet::new()
.run_until(async {
let mut tests: Vec<Pin<Box<dyn Future<Output = ()> + 'static>>> = vec![];

fn test(setup: &TestSetup, test: impl SmokeTest + 'static) -> Pin<Box<dyn Future<Output = ()> + 'static>> {
let setup = setup.clone();
Box::pin(async move {
let name = test.name();
let res = test.run(&setup).await;
match res {
Ok(_) => println!("✅ {name} passed"),
Err(e) => println!("❌ {name} failed: {}", e),
};
})
}

tests.push(test(&setup, PostgresSelect {
query: "SELECT".to_string(),
expected: "\n\n".to_string(),
}));
tests.push(test(&setup, PostgresSelect {
query: "SELECT 1 as x".to_string(),
expected: "x\n1\n".to_string(),
}));
tests.push(test(&setup, PostgresSelect {
query: "SELECT LIMIT 0".to_string(),
expected: "\n".to_string(),
}));
tests.push(test(&setup, EdgeQLSelect {
query: "select 1".to_string(),
expected: "1\n".to_string(),
}));

for test in tests {
test.await;
}

})
.await;
}
11 changes: 6 additions & 5 deletions rust/frontend/src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ async fn handle_stream_edgedb_binary(
trace!("UPDATE: {update:?}");
match update {
Auth(user, database, branch) => {
identity.set_branch(BranchDB::Branch(database));
identity.set_branch(branch);
identity.set_database(database);
identity.set_user(user);
auth_ready.store(true, Ordering::SeqCst);
}
Expand Down Expand Up @@ -592,7 +593,7 @@ async fn handle_stream_postgres_initial(
trace!("UPDATE: {update:?}");
match update {
Auth(user, database) => {
identity.set_branch(BranchDB::Branch(database));
identity.set_pg_database(database);
identity.set_user(user);
auth_ready.store(true, Ordering::SeqCst);
}
Expand Down Expand Up @@ -1013,7 +1014,7 @@ mod tests {
target: AuthTarget,
) -> impl Future<Output = Result<CredentialData, std::io::Error>> {
self.log(format!("lookup_auth: {:?}", identity));
async { Ok(CredentialData::Deny) }
async { Ok(CredentialData::Trust) }
}

fn accept_stream(
Expand Down Expand Up @@ -1120,14 +1121,14 @@ mod tests {
value: "name",
},
StartupNameValue {
name: "username",
name: "user",
value: "me",
},
],
}
.to_vec();
stm.write_all(&msg).await.unwrap();
assert_eq!(stm.read_u8().await.unwrap(), b'S');
assert_eq!(stm.read_u8().await.unwrap(), b'R'); // AuthenticationOk
Ok(())
});
}
Expand Down
15 changes: 15 additions & 0 deletions rust/frontend/src/python.rs
Original file line number Diff line number Diff line change
@@ -1 +1,16 @@
#[cfg(test)]
mod tests {
use pyo3::{types::PyAnyMethods, Python};

use super::*;

#[test]
fn test_python_extension() {
pyo3::prepare_freethreaded_python();
Python::with_gil(|py| {
let sys = py.import("sys").unwrap();
let version = sys.getattr("version").unwrap();
println!("Python version: {}", version);
});
}
}
Loading

0 comments on commit 19787c6

Please sign in to comment.