Skip to content

Commit

Permalink
Server-side SCRAM should work with stored keys, not hashed keys (#7830)
Browse files Browse the repository at this point in the history
In preparation for landing a new Rust front-end, implement the
server-side SCRAM correctly by using stored/server keys like Postgres
does (ie:
`SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>`).

This code is not currently used.
  • Loading branch information
mmastrac authored Oct 3, 2024
1 parent cdab0ef commit 47f66ca
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 59 deletions.
3 changes: 1 addition & 2 deletions edb/server/pgrust/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ mod stringprep_table;

pub use md5::md5_password;
pub use scram::{
generate_salted_password, generate_stored_key, ClientEnvironment, ClientTransaction,
SCRAMError, Sha256Out,
generate_salted_password, ClientEnvironment, ClientTransaction, SCRAMError, Sha256Out,
};
pub use stringprep::{sasl_normalize_password, sasl_normalize_password_bytes};
259 changes: 202 additions & 57 deletions edb/server/pgrust/src/auth/scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@

use base64::{prelude::BASE64_STANDARD, Engine};
use hmac::{Hmac, Mac};
use rand::Rng;
use sha2::{digest::FixedOutput, Digest, Sha256};
use std::borrow::Cow;
use std::str::FromStr;

use super::sasl_normalize_password_bytes;

Expand All @@ -87,8 +89,8 @@ pub enum SCRAMError {
}

pub trait ServerEnvironment {
fn get_password_parameters(&self, username: &str) -> (Cow<'static, str>, usize);
fn get_salted_password(&self, username: &str) -> Sha256Out;
fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize);
fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out);
fn generate_nonce(&self) -> String;
}

Expand Down Expand Up @@ -122,7 +124,7 @@ impl ServerTransaction {
nonce += &env.generate_nonce();
let response = ServerFirstResponse {
combined_nonce: nonce.to_string().into(),
salt,
salt: BASE64_STANDARD.encode(salt).into(),
iterations,
};
self.state =
Expand All @@ -137,23 +139,30 @@ impl ServerTransaction {
if message.channel_binding != CHANNEL_BINDING_ENCODED {
return Err(SCRAMError::ProtocolError);
}
let salted_password = env.get_salted_password(&first_message.username);
let (client_proof, server_verifier) = generate_proof(
let (stored_key, server_key) = env.get_stored_key(&first_message.username);

// Decode the provided client proof
let mut provided_proof = vec![];
BASE64_STANDARD
.decode_vec(message.proof.as_bytes(), &mut provided_proof)
.map_err(|_| SCRAMError::ProtocolError)?;

let (calculated_stored_key, server_signature) = generate_server_proof(
first_message.encode().as_bytes(),
first_response.encode().as_bytes(),
message.channel_binding.as_bytes(),
message.combined_nonce.as_bytes(),
&salted_password,
&provided_proof,
&server_key,
&stored_key,
);
let mut proof = vec![];
BASE64_STANDARD
.decode_vec(message.proof.as_bytes(), &mut proof)
.map_err(|_| SCRAMError::ProtocolError)?;
if proof != client_proof {

if calculated_stored_key.as_slice() != stored_key {
return Err(SCRAMError::ProtocolError);
}

self.state = ServerState::Success;
let verifier = BASE64_STANDARD.encode(server_verifier).into();
let verifier = BASE64_STANDARD.encode(server_signature).into();
Ok(Some(ServerFinalResponse { verifier }.encode().into_bytes()))
}
}
Expand Down Expand Up @@ -224,7 +233,7 @@ impl ClientTransaction {
let mut buffer = [0; 1024];
let salt = decode_salt(&message.salt, &mut buffer)?;
let salted_password = env.get_salted_password(&salt, message.iterations);
let (client_proof, server_verifier) = generate_proof(
let (client_proof, server_verifier) = generate_client_proof(
first_message.encode().as_bytes(),
message.encode().as_bytes(),
CHANNEL_BINDING_ENCODED.as_bytes(),
Expand Down Expand Up @@ -519,32 +528,123 @@ pub fn generate_salted_password(password: &[u8], salt: &[u8], iterations: usize)
u.as_slice().try_into().unwrap()
}

/// Generate a stored key compatible with PostgreSQL's encoding.
pub fn generate_stored_key(password: &[u8], salt: &[u8], iterations: usize) -> String {
let digest_key = generate_salted_password(password, salt, iterations);
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StoredKey {
pub iterations: usize,
pub salt: Vec<u8>,
pub stored_key: Sha256Out,
pub server_key: Sha256Out,
}

let mut client_key = hmac(&digest_key)
.chain_update(b"Client Key")
.finalize()
.into_bytes();
impl FromStr for StoredKey {
type Err = SCRAMError;

let stored_key = Sha256::digest(client_key);
// "SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>"

let server_key = hmac(&digest_key)
.chain_update(b"Server Key")
.finalize()
.into_bytes();
fn from_str(s: &str) -> Result<Self, Self::Err> {
let parts: Vec<&str> = s.split('$').collect();
if parts.len() != 3 || parts[0] != "SCRAM-SHA-256" {
return Err(SCRAMError::ProtocolError);
}

let iterations = parts[1]
.split(':')
.next()
.ok_or(SCRAMError::ProtocolError)?
.parse()
.map_err(|_| SCRAMError::ProtocolError)?;

let salt = BASE64_STANDARD
.decode(
parts[1]
.split(':')
.nth(1)
.ok_or(SCRAMError::ProtocolError)?,
)
.map_err(|_| SCRAMError::ProtocolError)?;

let key_parts: Vec<&str> = parts[2].split(':').collect();
if key_parts.len() != 2 {
return Err(SCRAMError::ProtocolError);
}

let stored_key = BASE64_STANDARD
.decode(key_parts[0])
.map_err(|_| SCRAMError::ProtocolError)?
.try_into()
.map_err(|_| SCRAMError::ProtocolError)?;

format!(
"SCRAM-SHA-256${}:{}${}:{}",
iterations,
BASE64_STANDARD.encode(salt),
BASE64_STANDARD.encode(stored_key),
BASE64_STANDARD.encode(server_key)
)
let server_key = BASE64_STANDARD
.decode(key_parts[1])
.map_err(|_| SCRAMError::ProtocolError)?
.try_into()
.map_err(|_| SCRAMError::ProtocolError)?;

Ok(StoredKey {
iterations,
salt,
stored_key,
server_key,
})
}
}
use std::fmt;

impl fmt::Display for StoredKey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"SCRAM-SHA-256${}:{}${}:{}",
self.iterations,
BASE64_STANDARD.encode(&self.salt),
BASE64_STANDARD.encode(self.stored_key),
BASE64_STANDARD.encode(self.server_key)
)
}
}

impl ServerEnvironment for StoredKey {
fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize) {
(Cow::Owned(self.salt.clone()), self.iterations)
}

fn generate_nonce(&self) -> String {
let nonce: [u8; 32] = rand::thread_rng().gen();
base64::engine::general_purpose::STANDARD.encode(nonce)
}

fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out) {
(self.stored_key, self.server_key)
}
}

fn generate_proof(
impl StoredKey {
/// Generate a stored key compatible with PostgreSQL's encoding.
pub fn generate(password: &[u8], salt: &[u8], iterations: usize) -> Self {
let digest_key = generate_salted_password(password, salt, iterations);

let client_key = hmac(&digest_key)
.chain_update(b"Client Key")
.finalize()
.into_bytes();

let stored_key = Sha256::digest(client_key);

let server_key = hmac(&digest_key)
.chain_update(b"Server Key")
.finalize()
.into_bytes();

Self {
iterations,
salt: salt.to_owned(),
stored_key: stored_key.into(),
server_key: server_key.into(),
}
}
}

fn generate_client_proof(
first_message_bare: &[u8],
server_first_message: &[u8],
channel_binding: &[u8],
Expand Down Expand Up @@ -596,6 +696,51 @@ fn generate_proof(
(client_signature, server_proof)
}

fn generate_server_proof(
first_message_bare: &[u8],
server_first_message: &[u8],
channel_binding: &[u8],
server_nonce: &[u8],
provided_proof: &[u8],
server_key: &[u8],
stored_key: &[u8],
) -> (Sha256Out, Sha256Out) {
let auth_message = [
(first_message_bare),
(b","),
(server_first_message),
(b",c="),
(channel_binding),
(b",r="),
(server_nonce),
];

let mut client_signature = hmac(stored_key);
for chunk in &auth_message {
client_signature.update(chunk);
}
let client_signature = client_signature.finalize_fixed();

let mut calculated_stored_key = [0u8; 32];
for (i, (&p, &c)) in provided_proof
.iter()
.zip(client_signature.iter())
.enumerate()
{
calculated_stored_key[i] = p ^ c;
}

let calculated_stored_key = Sha256::digest(calculated_stored_key);

let mut server_signature = hmac(server_key);
for chunk in &auth_message {
server_signature.update(chunk);
}
let server_signature = server_signature.finalize_fixed();

(calculated_stored_key.into(), server_signature.into())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -696,28 +841,19 @@ mod tests {
// Prohibited char (ffff)
#[case(b"\xef\xbf\xbf", "SCRAM-SHA-256$4096:Tdv5eCJIm+LU9QJBKO96gQ==$YXE4G3HKPwCmwo4FjiFKaiqVGCDTOpVETv+Fe6wWY9Q=:DK7MZ/OgGGgCDh6EfsmmcyFuaAD+T2Zh78sl+QDQFIo=")]
fn test_stored_key(#[case] password: &[u8], #[case] stored_key: &str) {
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine};
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};

let salt = stored_key
.split(':')
.nth(1)
.unwrap()
.split('$')
.next()
.unwrap();
let salt = BASE64_STANDARD.decode(salt).unwrap();
let generated_key = generate_stored_key(password, &salt, 4096);
assert_eq!(generated_key, stored_key);
let parsed_key = StoredKey::from_str(stored_key).unwrap();
assert_eq!(4096, parsed_key.iterations);
let generated_key = StoredKey::generate(password, &parsed_key.salt, parsed_key.iterations);
assert_eq!(generated_key, parsed_key);
assert_eq!(generated_key.to_string(), stored_key);
}

#[test]
fn test_client_proof() {
let mut buffer = [0; 128];
let salt = decode_salt(SALT, &mut buffer).unwrap();
let salted_password = generate_salted_password(PASSWORD, &salt, ITERATIONS);
let (client, server) = generate_proof(
let (client, server) = generate_client_proof(
format!("n={USERNAME},r={CLIENT_NONCE}").as_bytes(),
format!("r={CLIENT_NONCE}{SERVER_NONCE},s={SALT},i={ITERATIONS}").as_bytes(),
CHANNEL_BINDING_ENCODED.as_bytes(),
Expand Down Expand Up @@ -829,28 +965,37 @@ mod tests {
}
}
impl ServerEnvironment for Env {
fn get_salted_password(&self, username: &str) -> Sha256Out {
fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out) {
assert_eq!(username, "username");
generate_salted_password(b"password", b"hello", 4096)
let key = StoredKey::generate(b"password", b"hello", 4096);
(key.stored_key, key.server_key)
}
fn generate_nonce(&self) -> String {
"<<<server nonce>>>".into()
}
fn get_password_parameters(&self, username: &str) -> (Cow<'static, str>, usize) {
fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize) {
assert_eq!(username, "username");
(Cow::Borrowed("aGVsbG8="), 4096)
(Cow::Borrowed(b"hello"), 4096)
}
}
let env = Env {};

let message = client.process_message(&[], &env).unwrap().unwrap();
eprintln!("client: {:?}", String::from_utf8(message.clone()).unwrap());
assert_eq!(
String::from_utf8(message.clone()).unwrap(),
"n,,n=username,r=<<<client nonce>>>"
);
let message = server.process_message(&message, &env).unwrap().unwrap();
eprintln!("server: {:?}", String::from_utf8(message.clone()).unwrap());
assert_eq!(
String::from_utf8(message.clone()).unwrap(),
"r=<<<client nonce>>><<<server nonce>>>,s=aGVsbG8=,i=4096"
);
let message = client.process_message(&message, &env).unwrap().unwrap();
eprintln!("client: {:?}", String::from_utf8(message.clone()).unwrap());
assert_eq!(String::from_utf8(message.clone()).unwrap(), "c=biws,r=<<<client nonce>>><<<server nonce>>>,p=621h6u6V3axb7mNYHNgTspTZ3SqILcxuJOsFu5wMjV8=");
let message = server.process_message(&message, &env).unwrap().unwrap();
eprintln!("server: {:?}", String::from_utf8(message.clone()).unwrap());
assert_eq!(
String::from_utf8(message.clone()).unwrap(),
"v=moj4kNnZKB3wjXZeQsKYI9luTTakwgH8r0NdGOjugRY="
);
assert!(client.process_message(&message, &env).unwrap().is_none());
assert!(client.success());
assert!(server.success());
Expand Down

0 comments on commit 47f66ca

Please sign in to comment.