Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Nov 26, 2024
1 parent bf7e0cf commit 0dba69b
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 40 deletions.
6 changes: 2 additions & 4 deletions edb/server/frontend/examples/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use edb_frontend::listener::*;
use edb_frontend::service::*;
use edb_frontend::stream::*;
use hyper::Response;
use pgrust::auth::AuthType;
use pgrust::auth::CredentialData;
use pgrust::auth::StoredHash;
use tokio::io::AsyncReadExt;
Expand All @@ -21,10 +22,7 @@ impl BabelfishService for ExampleService {
) -> impl Future<Output = Result<CredentialData, std::io::Error>> {
eprintln!("lookup_auth: {:?}", identity);
async move {
Ok(CredentialData::Md5(StoredHash::generate(
b"password",
&identity.user,
)))
Ok(CredentialData::new(AuthType::Trust, "matt".to_owned(), "password".to_owned()))
}
}

Expand Down
124 changes: 102 additions & 22 deletions edb/server/frontend/src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,6 @@ use crate::{
use futures::StreamExt;
use hyper::{upgrade::OnUpgrade, Request, Response, StatusCode};
use openssl::ssl::{AlpnError, NameType, SniError, Ssl, SslAlert, SslContext, SslMethod};
use pgrust::{
errors::{PgError, PgErrorConnectionException, PgErrorInvalidAuthorizationSpecification},
handshake::{
server::{ConnectionDrive, ConnectionEvent, ServerState},
ConnectionSslRequirement,
},
};
use scopeguard::defer;
use std::sync::{
atomic::{AtomicBool, Ordering},
Expand Down Expand Up @@ -360,18 +353,102 @@ async fn handle_stream_edgedb_binary(
identity: ConnectionIdentityBuilder,
bound_config: impl IsBoundConfig,
) -> Result<(), std::io::Error> {
socket.read_u8().await?;
let mut length_bytes = [0; 4];
socket.read_exact(&mut length_bytes).await?;
let length = u32::from_be_bytes(length_bytes) - 4;
let mut handshake = vec![0; length as usize];
socket.read_exact(&mut handshake).await?;
println!("Handshake:\n{:?}", hexdump::hexdump(&handshake));
_ = socket
.write_all(StreamType::EdgeDBBinary.go_away_message())
use pgrust::{
errors::edgedb::{EdbError},
handshake::{
edgedb_server::{ConnectionDrive, ConnectionEvent, ServerState},
},
};

let mut resolved_identity = None;
let mut server_state = ServerState::new();
let mut startup_params = HashMap::with_capacity(16);
let mut send_buf = Mutex::new(bytes::BytesMut::new());
let auth_ready = AtomicBool::new(false);
let params_ready = AtomicBool::new(false);
let mut update = |update: ConnectionEvent<'_>| {
use ConnectionEvent::*;
trace!("UPDATE: {update:?}");
match update {
Auth(user, database) => {
identity.set_branch(BranchDB::Branch(database));
identity.set_user(user);
auth_ready.store(true, Ordering::SeqCst);
}
Parameter(name, value) => {
startup_params.insert(name.to_owned(), value.to_owned());
}
Params => params_ready.store(true, Ordering::SeqCst),
Send(bytes) => {
// TODO: Reduce copies and allocations here
send_buf.lock().unwrap().extend_from_slice(&bytes.to_vec());
}
ServerError(e) => {
trace!("ERROR {e:?}");
}
StateChanged(..) => {}
}
Ok(())
};

while !server_state.is_done() || !send_buf.lock().unwrap().is_empty() {
let mut send_buf = std::mem::take(&mut *send_buf.lock().unwrap());
if !send_buf.is_empty() {
eprintln!("Sending {send_buf:?}");
socket.write_all(&send_buf).await?;
} else if auth_ready.swap(false, Ordering::SeqCst) {
let built = match identity.clone().build() {
Ok(built) => built,
Err(e) => {
server_state.drive(ConnectionDrive::Fail(EdbError::AuthenticationError, "Missing database or user"), &mut update).unwrap();
return Ok(());
}
};
resolved_identity = Some(built);
let auth = bound_config
.service()
.lookup_auth(
resolved_identity.clone().unwrap(),
AuthTarget::Stream(StreamLanguage::Postgres),
)
.await?;
server_state
.drive(
ConnectionDrive::AuthInfo(auth.auth_type(), auth),
&mut update,
)
.unwrap();
} else if params_ready.swap(false, Ordering::SeqCst) {
server_state
.drive(ConnectionDrive::Ready(Default::default()), &mut update)
.unwrap();
} else {
let mut b = [0; 512];
let n = socket.read(&mut b).await?;
if n == 0 {
// EOF
return Ok(());
}
let res = server_state.drive(ConnectionDrive::RawMessage(&b[..n]), &mut update);
if res.is_err() {
// TODO?
error!("{res:?}");
return Ok(());
}
}
}

let socket = socket.upgrade(StreamPropertiesBuilder {
stream_params: Some(startup_params),
..Default::default()
});
bound_config
.service()
.accept_stream(resolved_identity.unwrap(), StreamLanguage::EdgeDB, socket)
.await;
_ = socket.shutdown().await;

Ok(())

}

async fn handle_stream_http1x(
Expand Down Expand Up @@ -491,11 +568,14 @@ async fn handle_stream_postgres_initial(
identity: ConnectionIdentityBuilder,
bound_config: impl IsBoundConfig,
) -> Result<(), std::io::Error> {
use pgrust::protocol::{
match_message, messages::Initial, meta::InitialMessage, StartupMessage, StructBuffer,
use pgrust::{
errors::{PgError, PgErrorInvalidAuthorizationSpecification},
handshake::{
server::{ConnectionDrive, ConnectionEvent, ServerState},
ConnectionSslRequirement,
},
};

// We'll handle SSL upgrades here
let mut resolved_identity = None;
let mut server_state = ServerState::new(ConnectionSslRequirement::Disable);
let mut startup_params = HashMap::with_capacity(16);
Expand Down Expand Up @@ -529,7 +609,7 @@ async fn handle_stream_postgres_initial(
Ok(())
};

while !server_state.is_done() {
while !server_state.is_done() || !send_buf.lock().unwrap().is_empty() {
let mut send_buf = std::mem::take(&mut *send_buf.lock().unwrap());
if !send_buf.is_empty() {
eprintln!("Sending {send_buf:?}");
Expand Down Expand Up @@ -1026,7 +1106,7 @@ mod tests {

#[test]
fn test_raw_postgres() {
use pgrust::protocol::builder::{StartupMessage, StartupNameValue};
use pgrust::protocol::postgres::builder::{StartupMessage, StartupNameValue};
run_test_service(TestMode::Tcp, |mut stm| async move {
let msg = StartupMessage {
params: &[
Expand Down
4 changes: 4 additions & 0 deletions edb/server/frontend/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ pub enum AuthTarget {

#[derive(Clone, Debug)]
pub enum BranchDB {
/// Branch only.
Branch(String),
/// DB only (legacy).
DB(String),
/// Branch and DB (advanced).
BranchDB(String, String),
}

#[derive(thiserror::Error, Debug)]
Expand Down
53 changes: 49 additions & 4 deletions edb/server/pgrust/src/handshake/edgedb_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ impl ServerState {
trace!("SERVER DRIVE: {:?} {:?}", self.state, drive);
let res = match drive {
ConnectionDrive::RawMessage(raw) => self.buffer.push_fallible(raw, |message| {
trace!("Parsed message: {message:?}");
self.state
.drive_inner(ConnectionDrive::Message(message), update)
}),
Expand Down Expand Up @@ -192,11 +193,55 @@ impl ServerStateImpl {
(Initial, ConnectionDrive::Message(message)) => {
match_message!(message, Message {
(ClientHandshake as handshake) => {
trace!("ClientHandshake: {handshake:?}");

// The handshake should generate an event rather than hardcoding the min/max protocol versions.

// No extensions are supported
if !handshake.extensions().is_empty() {
update.send(EdgeDBBackendBuilder::ServerHandshake(builder::ServerHandshake { major_ver: 2, minor_ver: 0, extensions: &[] }))?;
return Ok(());
}

// We support 1.x and 2.0
let major_ver = handshake.major_ver();
let minor_ver = handshake.minor_ver();
// TODO: Check version compatibility
*self = AuthInfo(String::new()); // No user info in EdgeDB
update.auth(String::new(), String::new())?;
match (major_ver, minor_ver) {
(..=0, _) => {
update.send(EdgeDBBackendBuilder::ServerHandshake(builder::ServerHandshake { major_ver: 1, minor_ver: 0, extensions: &[] }))?;
return Ok(());
}
(1, 1..) => {
// 1.(1+) never existed
return Err(PROTOCOL_VERSION_ERROR);
}
(2, 1..) | (3.., _) => {
update.send(EdgeDBBackendBuilder::ServerHandshake(builder::ServerHandshake { major_ver: 2, minor_ver: 0, extensions: &[] }))?;
return Ok(());
}
_ => {}
}

let mut user = String::new();
let mut database = String::new();
let mut branch = String::new();
for param in handshake.params() {
match param.name().to_str()? {
"user" => user = param.value().to_owned()?,
"database" => database = param.value().to_owned()?,
"branch" => branch = param.value().to_owned()?,
_ => {}
}
update.parameter(param.name().to_str()?, param.value().to_str()?);
}
if user.is_empty() {
return Err(AUTH_ERROR.into());
}
if database.is_empty() {
database = user.clone();
}
*self = AuthInfo(user.clone());
update.auth(user, database)?;
},
unknown => {
log_unknown_message(unknown, "Initial")?;
Expand Down Expand Up @@ -271,7 +316,7 @@ impl ServerStateImpl {
update.send(EdgeDBBackendBuilder::ReadyForCommand(
builder::ReadyForCommand {
annotations: &[],
transaction_state: b'I',
transaction_state: 0x49,
},
))?;
*self = Ready;
Expand Down
2 changes: 1 addition & 1 deletion edb/server/pgrust/src/handshake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub enum ConnectionSslRequirement {
}

mod client_state_machine;
mod edgedb_server;
pub mod edgedb_server;
mod server_auth;
mod server_state_machine;

Expand Down
4 changes: 2 additions & 2 deletions edb/server/pgrust/src/protocol/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,13 @@ impl FieldAccess<LStringMeta> {
}
#[inline(always)]
pub fn copy_to_buf(buf: &mut BufWriter, value: &str) {
let len = value.len();
let len = value.len() as u32;
buf.write(&len.to_be_bytes());
buf.write(value.as_bytes());
}
#[inline(always)]
pub fn copy_to_buf_ref(buf: &mut BufWriter, value: &str) {
let len = value.len();
let len = value.len() as u32;
buf.write(&len.to_be_bytes());
buf.write(value.as_bytes());
}
Expand Down
2 changes: 1 addition & 1 deletion edb/server/pgrust/src/protocol/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ macro_rules! protocol {
($( $( #[ $sdoc:meta ] )* struct $name:ident $(: $super:ident)? { $($struct:tt)+ } )+) => {
$(
paste::paste!(
#[allow(unused_imports)]
pub(crate) mod [<__ $name:lower>] {
#[allow(unused_imports)]
use super::meta::*;
use $crate::protocol::meta::*;
use $crate::protocol::gen::*;
Expand Down
11 changes: 11 additions & 0 deletions edb/server/pgrust/src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -609,4 +609,15 @@ mod tests {

fuzz_test::<meta::FunctionCall>(message);
}

#[test]
fn test_edgedb_sasl() {
use crate::protocol::edgedb::*;

assert_eq!(builder::AuthenticationRequiredSASLMessage {
methods: &["SCRAM-SHA-256"]
}.to_vec(), vec![82, 0, 0, 0, 33, 0, 0, 0, 10, 0, 0, 0, 1, 0, 0, 0, 13, 83, 67, 82, 65, 77, 45, 83, 72, 65, 45, 50, 53, 54]);


}
}
1 change: 0 additions & 1 deletion edb/server/pgrust/src/protocol/postgres.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use super::gen::protocol;
use super::message_group::message_group;
use crate::protocol::meta::*;

message_group!(
/// The `Backend` message group contains messages sent from the backend to the frontend.
Expand Down
8 changes: 4 additions & 4 deletions edb/server/pgrust/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
},
ConnectionSslRequirement,
},
protocol::{meta, SSLResponse, StructBuffer},
protocol::{postgres::{data::SSLResponse, meta, FrontendBuilder, InitialBuilder}, StructBuffer},
};
use pyo3::{
buffer::PyBuffer,
Expand Down Expand Up @@ -405,7 +405,7 @@ struct PyConnectionStateUpdate {
impl ConnectionStateSend for PyConnectionStateUpdate {
fn send_initial(
&mut self,
message: crate::protocol::definition::InitialBuilder,
message: InitialBuilder,
) -> Result<(), std::io::Error> {
Python::with_gil(|py| {
let bytes = PyByteArray::new(py, &message.to_vec());
Expand All @@ -419,7 +419,7 @@ impl ConnectionStateSend for PyConnectionStateUpdate {

fn send(
&mut self,
message: crate::protocol::definition::FrontendBuilder,
message: FrontendBuilder,
) -> Result<(), std::io::Error> {
Python::with_gil(|py| {
let bytes = PyBytes::new(py, &message.to_vec());
Expand Down Expand Up @@ -476,7 +476,7 @@ impl ConnectionStateUpdate for PyConnectionStateUpdate {
});
}

fn auth(&mut self, auth: crate::handshake::AuthType) {
fn auth(&mut self, auth: crate::auth::AuthType) {
Python::with_gil(|py| {
if let Err(e) = self.py_update.call_method1(py, "auth", (auth as u8,)) {
eprintln!("Error in auth: {:?}", e);
Expand Down
3 changes: 2 additions & 1 deletion edb/server/pgrust/tests/real_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ fn run_postgres(
let content = std::fs::read_to_string(&pg_hba_path)?;
let modified_content = content
.lines()
.filter(|line| !line.starts_with("#") && !line.is_empty())
.map(|line| {
if line.trim_start().starts_with("host") {
line.replacen("host", "hostssl", 1)
Expand All @@ -202,7 +203,7 @@ fn run_postgres(
})
.collect::<Vec<String>>()
.join("\n");
eprintln!("pg_hba.conf:\n{modified_content}");
eprintln!("pg_hba.conf:\n==========\n{modified_content}\n==========");
std::fs::write(&pg_hba_path, modified_content)?;

command.arg("-l");
Expand Down

0 comments on commit 0dba69b

Please sign in to comment.