Skip to content

Commit

Permalink
Polish
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Oct 8, 2024
1 parent 269d8cd commit 65208d9
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 35 deletions.
37 changes: 18 additions & 19 deletions edb/server/pgrust/src/errors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use itertools::Itertools;
use paste::paste;
use std::{collections::HashMap, str::FromStr};
use tracing::warn;

use crate::protocol::ErrorResponse;

Expand Down Expand Up @@ -276,14 +274,14 @@ pub struct PgServerError {
pub code: PgError,
pub severity: PgErrorSeverity,
pub message: String,
pub extra: HashMap<ServerErrorField, String>,
pub extra: HashMap<PgServerErrorField, String>,
}

impl PgServerError {
pub fn new(
code: PgError,
arg: impl AsRef<str>,
extra: HashMap<ServerErrorField, String>,
extra: HashMap<PgServerErrorField, String>,
) -> Self {
Self {
code,
Expand All @@ -293,17 +291,18 @@ impl PgServerError {
}
}

pub fn fields(&self) -> impl Iterator<Item = (ServerErrorField, &str)> {
/// Iterate all the fields of this error.
pub fn fields(&self) -> impl Iterator<Item = (PgServerErrorField, &str)> {
let fields = [
(PgServerErrorField::Code, self.code.get_error_string()),
(PgServerErrorField::Message, self.message.as_str()),
(
PgServerErrorField::SeverityNonLocalized,
self.severity.as_ref(),
),
];
Iterator::chain(
[
(ServerErrorField::Code, self.code.get_error_string()),
(ServerErrorField::Message, self.message.as_str()),
(
ServerErrorField::SeverityNonLocalized,
self.severity.as_ref(),
),
]
.into_iter(),
fields.into_iter(),
self.extra.iter().map(|(f, e)| (*f, e.as_str())),
)
}
Expand All @@ -330,10 +329,10 @@ impl From<ErrorResponse<'_>> for PgServerError {

for field in error.fields() {
let value = field.value().to_string_lossy().into_owned();
match ServerErrorField::try_from(field.etype()) {
Ok(ServerErrorField::Code) => code = value,
Ok(ServerErrorField::Message) => message = value,
Ok(ServerErrorField::SeverityNonLocalized) => {
match PgServerErrorField::try_from(field.etype()) {
Ok(PgServerErrorField::Code) => code = value,
Ok(PgServerErrorField::Message) => message = value,
Ok(PgServerErrorField::SeverityNonLocalized) => {
severity = PgErrorSeverity::from_str(&value).unwrap_or_default()
}
Ok(field_type) => {
Expand Down Expand Up @@ -364,7 +363,7 @@ impl From<ErrorResponse<'_>> for PgServerError {
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, derive_more::TryFrom)]
#[try_from(repr)]
pub enum ServerErrorField {
pub enum PgServerErrorField {
/// Severity: ERROR, FATAL, PANIC, WARNING, NOTICE, DEBUG, INFO, or LOG
Severity = b'S',
/// Severity (non-localized): ERROR, FATAL, PANIC, WARNING, NOTICE, DEBUG, INFO, or LOG
Expand Down
3 changes: 2 additions & 1 deletion edb/server/pgrust/src/handshake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ mod tests {
},
ConnectionSslRequirement::Disable,
);
let mut server = server::ServerState::new(ConnectionSslRequirement::Disable);
let mut server =
server::ServerState::new(ConnectionSslRequirement::Disable, 0x1234, 0x4321);

// We test all variations here, but not all combinations will result in
// valid auth, even with a correct password.
Expand Down
35 changes: 20 additions & 15 deletions edb/server/pgrust/src/handshake/server_state_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
connection::ConnectionError,
errors::{
PgError, PgErrorConnectionException, PgErrorFeatureNotSupported,
PgErrorInvalidAuthorizationSpecification, PgServerError, ServerErrorField,
PgErrorInvalidAuthorizationSpecification, PgServerError, PgServerErrorField,
},
handshake::AuthType,
protocol::{
Expand All @@ -13,7 +13,7 @@ use crate::{
},
};
use rand::Rng;
use std::{collections::HashMap, str::Utf8Error};
use std::str::Utf8Error;
use tracing::{error, trace, warn};

#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -78,7 +78,6 @@ enum PredeterminedResult {
#[derive(Debug)]
struct ServerEnvironmentImpl {
ssl_requirement: ConnectionSslRequirement,
parameters: HashMap<String, String>,
pid: i32,
key: i32,
}
Expand Down Expand Up @@ -139,7 +138,7 @@ enum ServerStateImpl {
/// Password-based authentication in progress
AuthenticatingPassword(String, CredentialData),
/// MD5 authentication in progress
AuthenticatingMD5(String, Option<PredeterminedResult>, StoredHash),
AuthenticatingMD5(Option<PredeterminedResult>, StoredHash),
/// SASL authentication in progress
AuthenticatingSASL(ServerTransaction, Option<PredeterminedResult>, StoredKey),
/// Synchronizing connection parameters
Expand All @@ -165,19 +164,19 @@ fn send_error(
update.send(BackendBuilder::ErrorResponse(builder::ErrorResponse {
fields: &[
builder::ErrorField {
etype: ServerErrorField::Severity as _,
etype: PgServerErrorField::Severity as _,
value: "ERROR",
},
builder::ErrorField {
etype: ServerErrorField::SeverityNonLocalized as _,
etype: PgServerErrorField::SeverityNonLocalized as _,
value: "ERROR",
},
builder::ErrorField {
etype: ServerErrorField::Code as _,
etype: PgServerErrorField::Code as _,
value: std::str::from_utf8(&code.to_code()).unwrap(),
},
builder::ErrorField {
etype: ServerErrorField::Message as _,
etype: PgServerErrorField::Message as _,
value: message,
},
],
Expand Down Expand Up @@ -206,14 +205,13 @@ const PROTOCOL_VERSION_ERROR: ServerError = ServerError::Protocol(PgError::Featu
));

impl ServerState {
pub fn new(ssl_requirement: ConnectionSslRequirement) -> Self {
pub fn new(ssl_requirement: ConnectionSslRequirement, pid: i32, key: i32) -> Self {
Self {
state: ServerStateImpl::Initial(false),
environment: ServerEnvironmentImpl {
ssl_requirement,
parameters: HashMap::new(),
pid: rand::thread_rng().gen(),
key: rand::thread_rng().gen(),
pid,
key,
},
}
}
Expand Down Expand Up @@ -332,7 +330,7 @@ impl ServerState {
}
};
let salt = hash.salt;
self.state = AuthenticatingMD5(username, result, hash);
self.state = AuthenticatingMD5(result, hash);
update.send(BackendBuilder::AuthenticationMD5Password(
builder::AuthenticationMD5Password { salt: salt },
))?;
Expand Down Expand Up @@ -411,7 +409,7 @@ impl ServerState {
}
});
}
(AuthenticatingMD5(username, results, md5), ConnectionDrive::Message(message)) => {
(AuthenticatingMD5(results, md5), ConnectionDrive::Message(message)) => {
match_message!(message, Message {
(PasswordMessage as password) => {
let password = password.password();
Expand Down Expand Up @@ -450,7 +448,10 @@ impl ServerState {
}))?;
},
Ok(None) => return Err(PASSWORD_ERROR),
Err(e) => return Err(PASSWORD_ERROR),
Err(e) => {
error!("SCRAM auth failed: {e:?}");
return Err(PASSWORD_ERROR);
}
}
},
unknown => {
Expand Down Expand Up @@ -493,6 +494,10 @@ impl ServerState {
}))?;
}
(Synchronizing, ConnectionDrive::Ready) => {
update.send(BackendBuilder::BackendKeyData(builder::BackendKeyData {
key: self.environment.key,
pid: self.environment.pid,
}))?;
update.send(BackendBuilder::ReadyForQuery(builder::ReadyForQuery {
status: b'I',
}))?;
Expand Down

0 comments on commit 65208d9

Please sign in to comment.