From 47f66ca3f7edcaf50dc840ab424d16fff876cb85 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Thu, 3 Oct 2024 17:42:07 -0600 Subject: [PATCH] Server-side SCRAM should work with stored keys, not hashed keys (#7830) In preparation for landing a new Rust front-end, implement the server-side SCRAM correctly by using stored/server keys like Postgres does (ie: `SCRAM-SHA-256$:$:`). This code is not currently used. --- edb/server/pgrust/src/auth/mod.rs | 3 +- edb/server/pgrust/src/auth/scram.rs | 259 ++++++++++++++++++++++------ 2 files changed, 203 insertions(+), 59 deletions(-) diff --git a/edb/server/pgrust/src/auth/mod.rs b/edb/server/pgrust/src/auth/mod.rs index ae3c8891e3d..fdc0810d141 100644 --- a/edb/server/pgrust/src/auth/mod.rs +++ b/edb/server/pgrust/src/auth/mod.rs @@ -5,7 +5,6 @@ mod stringprep_table; pub use md5::md5_password; pub use scram::{ - generate_salted_password, generate_stored_key, ClientEnvironment, ClientTransaction, - SCRAMError, Sha256Out, + generate_salted_password, ClientEnvironment, ClientTransaction, SCRAMError, Sha256Out, }; pub use stringprep::{sasl_normalize_password, sasl_normalize_password_bytes}; diff --git a/edb/server/pgrust/src/auth/scram.rs b/edb/server/pgrust/src/auth/scram.rs index 9bebf8e93ee..4f326aa88a9 100644 --- a/edb/server/pgrust/src/auth/scram.rs +++ b/edb/server/pgrust/src/auth/scram.rs @@ -69,8 +69,10 @@ use base64::{prelude::BASE64_STANDARD, Engine}; use hmac::{Hmac, Mac}; +use rand::Rng; use sha2::{digest::FixedOutput, Digest, Sha256}; use std::borrow::Cow; +use std::str::FromStr; use super::sasl_normalize_password_bytes; @@ -87,8 +89,8 @@ pub enum SCRAMError { } pub trait ServerEnvironment { - fn get_password_parameters(&self, username: &str) -> (Cow<'static, str>, usize); - fn get_salted_password(&self, username: &str) -> Sha256Out; + fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize); + fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out); fn generate_nonce(&self) -> String; } @@ -122,7 +124,7 @@ impl ServerTransaction { nonce += &env.generate_nonce(); let response = ServerFirstResponse { combined_nonce: nonce.to_string().into(), - salt, + salt: BASE64_STANDARD.encode(salt).into(), iterations, }; self.state = @@ -137,23 +139,30 @@ impl ServerTransaction { if message.channel_binding != CHANNEL_BINDING_ENCODED { return Err(SCRAMError::ProtocolError); } - let salted_password = env.get_salted_password(&first_message.username); - let (client_proof, server_verifier) = generate_proof( + let (stored_key, server_key) = env.get_stored_key(&first_message.username); + + // Decode the provided client proof + let mut provided_proof = vec![]; + BASE64_STANDARD + .decode_vec(message.proof.as_bytes(), &mut provided_proof) + .map_err(|_| SCRAMError::ProtocolError)?; + + let (calculated_stored_key, server_signature) = generate_server_proof( first_message.encode().as_bytes(), first_response.encode().as_bytes(), message.channel_binding.as_bytes(), message.combined_nonce.as_bytes(), - &salted_password, + &provided_proof, + &server_key, + &stored_key, ); - let mut proof = vec![]; - BASE64_STANDARD - .decode_vec(message.proof.as_bytes(), &mut proof) - .map_err(|_| SCRAMError::ProtocolError)?; - if proof != client_proof { + + if calculated_stored_key.as_slice() != stored_key { return Err(SCRAMError::ProtocolError); } + self.state = ServerState::Success; - let verifier = BASE64_STANDARD.encode(server_verifier).into(); + let verifier = BASE64_STANDARD.encode(server_signature).into(); Ok(Some(ServerFinalResponse { verifier }.encode().into_bytes())) } } @@ -224,7 +233,7 @@ impl ClientTransaction { let mut buffer = [0; 1024]; let salt = decode_salt(&message.salt, &mut buffer)?; let salted_password = env.get_salted_password(&salt, message.iterations); - let (client_proof, server_verifier) = generate_proof( + let (client_proof, server_verifier) = generate_client_proof( first_message.encode().as_bytes(), message.encode().as_bytes(), CHANNEL_BINDING_ENCODED.as_bytes(), @@ -519,32 +528,123 @@ pub fn generate_salted_password(password: &[u8], salt: &[u8], iterations: usize) u.as_slice().try_into().unwrap() } -/// Generate a stored key compatible with PostgreSQL's encoding. -pub fn generate_stored_key(password: &[u8], salt: &[u8], iterations: usize) -> String { - let digest_key = generate_salted_password(password, salt, iterations); +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct StoredKey { + pub iterations: usize, + pub salt: Vec, + pub stored_key: Sha256Out, + pub server_key: Sha256Out, +} - let mut client_key = hmac(&digest_key) - .chain_update(b"Client Key") - .finalize() - .into_bytes(); +impl FromStr for StoredKey { + type Err = SCRAMError; - let stored_key = Sha256::digest(client_key); + // "SCRAM-SHA-256$:$:" - let server_key = hmac(&digest_key) - .chain_update(b"Server Key") - .finalize() - .into_bytes(); + fn from_str(s: &str) -> Result { + let parts: Vec<&str> = s.split('$').collect(); + if parts.len() != 3 || parts[0] != "SCRAM-SHA-256" { + return Err(SCRAMError::ProtocolError); + } + + let iterations = parts[1] + .split(':') + .next() + .ok_or(SCRAMError::ProtocolError)? + .parse() + .map_err(|_| SCRAMError::ProtocolError)?; + + let salt = BASE64_STANDARD + .decode( + parts[1] + .split(':') + .nth(1) + .ok_or(SCRAMError::ProtocolError)?, + ) + .map_err(|_| SCRAMError::ProtocolError)?; + + let key_parts: Vec<&str> = parts[2].split(':').collect(); + if key_parts.len() != 2 { + return Err(SCRAMError::ProtocolError); + } + + let stored_key = BASE64_STANDARD + .decode(key_parts[0]) + .map_err(|_| SCRAMError::ProtocolError)? + .try_into() + .map_err(|_| SCRAMError::ProtocolError)?; - format!( - "SCRAM-SHA-256${}:{}${}:{}", - iterations, - BASE64_STANDARD.encode(salt), - BASE64_STANDARD.encode(stored_key), - BASE64_STANDARD.encode(server_key) - ) + let server_key = BASE64_STANDARD + .decode(key_parts[1]) + .map_err(|_| SCRAMError::ProtocolError)? + .try_into() + .map_err(|_| SCRAMError::ProtocolError)?; + + Ok(StoredKey { + iterations, + salt, + stored_key, + server_key, + }) + } +} +use std::fmt; + +impl fmt::Display for StoredKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "SCRAM-SHA-256${}:{}${}:{}", + self.iterations, + BASE64_STANDARD.encode(&self.salt), + BASE64_STANDARD.encode(self.stored_key), + BASE64_STANDARD.encode(self.server_key) + ) + } +} + +impl ServerEnvironment for StoredKey { + fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize) { + (Cow::Owned(self.salt.clone()), self.iterations) + } + + fn generate_nonce(&self) -> String { + let nonce: [u8; 32] = rand::thread_rng().gen(); + base64::engine::general_purpose::STANDARD.encode(nonce) + } + + fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out) { + (self.stored_key, self.server_key) + } } -fn generate_proof( +impl StoredKey { + /// Generate a stored key compatible with PostgreSQL's encoding. + pub fn generate(password: &[u8], salt: &[u8], iterations: usize) -> Self { + let digest_key = generate_salted_password(password, salt, iterations); + + let client_key = hmac(&digest_key) + .chain_update(b"Client Key") + .finalize() + .into_bytes(); + + let stored_key = Sha256::digest(client_key); + + let server_key = hmac(&digest_key) + .chain_update(b"Server Key") + .finalize() + .into_bytes(); + + Self { + iterations, + salt: salt.to_owned(), + stored_key: stored_key.into(), + server_key: server_key.into(), + } + } +} + +fn generate_client_proof( first_message_bare: &[u8], server_first_message: &[u8], channel_binding: &[u8], @@ -596,6 +696,51 @@ fn generate_proof( (client_signature, server_proof) } +fn generate_server_proof( + first_message_bare: &[u8], + server_first_message: &[u8], + channel_binding: &[u8], + server_nonce: &[u8], + provided_proof: &[u8], + server_key: &[u8], + stored_key: &[u8], +) -> (Sha256Out, Sha256Out) { + let auth_message = [ + (first_message_bare), + (b","), + (server_first_message), + (b",c="), + (channel_binding), + (b",r="), + (server_nonce), + ]; + + let mut client_signature = hmac(stored_key); + for chunk in &auth_message { + client_signature.update(chunk); + } + let client_signature = client_signature.finalize_fixed(); + + let mut calculated_stored_key = [0u8; 32]; + for (i, (&p, &c)) in provided_proof + .iter() + .zip(client_signature.iter()) + .enumerate() + { + calculated_stored_key[i] = p ^ c; + } + + let calculated_stored_key = Sha256::digest(calculated_stored_key); + + let mut server_signature = hmac(server_key); + for chunk in &auth_message { + server_signature.update(chunk); + } + let server_signature = server_signature.finalize_fixed(); + + (calculated_stored_key.into(), server_signature.into()) +} + #[cfg(test)] mod tests { use super::*; @@ -696,20 +841,11 @@ mod tests { // Prohibited char (ffff) #[case(b"\xef\xbf\xbf", "SCRAM-SHA-256$4096:Tdv5eCJIm+LU9QJBKO96gQ==$YXE4G3HKPwCmwo4FjiFKaiqVGCDTOpVETv+Fe6wWY9Q=:DK7MZ/OgGGgCDh6EfsmmcyFuaAD+T2Zh78sl+QDQFIo=")] fn test_stored_key(#[case] password: &[u8], #[case] stored_key: &str) { - use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine}; - use hmac::{Hmac, Mac}; - use sha2::{Digest, Sha256}; - - let salt = stored_key - .split(':') - .nth(1) - .unwrap() - .split('$') - .next() - .unwrap(); - let salt = BASE64_STANDARD.decode(salt).unwrap(); - let generated_key = generate_stored_key(password, &salt, 4096); - assert_eq!(generated_key, stored_key); + let parsed_key = StoredKey::from_str(stored_key).unwrap(); + assert_eq!(4096, parsed_key.iterations); + let generated_key = StoredKey::generate(password, &parsed_key.salt, parsed_key.iterations); + assert_eq!(generated_key, parsed_key); + assert_eq!(generated_key.to_string(), stored_key); } #[test] @@ -717,7 +853,7 @@ mod tests { let mut buffer = [0; 128]; let salt = decode_salt(SALT, &mut buffer).unwrap(); let salted_password = generate_salted_password(PASSWORD, &salt, ITERATIONS); - let (client, server) = generate_proof( + let (client, server) = generate_client_proof( format!("n={USERNAME},r={CLIENT_NONCE}").as_bytes(), format!("r={CLIENT_NONCE}{SERVER_NONCE},s={SALT},i={ITERATIONS}").as_bytes(), CHANNEL_BINDING_ENCODED.as_bytes(), @@ -829,28 +965,37 @@ mod tests { } } impl ServerEnvironment for Env { - fn get_salted_password(&self, username: &str) -> Sha256Out { + fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out) { assert_eq!(username, "username"); - generate_salted_password(b"password", b"hello", 4096) + let key = StoredKey::generate(b"password", b"hello", 4096); + (key.stored_key, key.server_key) } fn generate_nonce(&self) -> String { "<<>>".into() } - fn get_password_parameters(&self, username: &str) -> (Cow<'static, str>, usize) { + fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize) { assert_eq!(username, "username"); - (Cow::Borrowed("aGVsbG8="), 4096) + (Cow::Borrowed(b"hello"), 4096) } } let env = Env {}; - let message = client.process_message(&[], &env).unwrap().unwrap(); - eprintln!("client: {:?}", String::from_utf8(message.clone()).unwrap()); + assert_eq!( + String::from_utf8(message.clone()).unwrap(), + "n,,n=username,r=<<>>" + ); let message = server.process_message(&message, &env).unwrap().unwrap(); - eprintln!("server: {:?}", String::from_utf8(message.clone()).unwrap()); + assert_eq!( + String::from_utf8(message.clone()).unwrap(), + "r=<<>><<>>,s=aGVsbG8=,i=4096" + ); let message = client.process_message(&message, &env).unwrap().unwrap(); - eprintln!("client: {:?}", String::from_utf8(message.clone()).unwrap()); + assert_eq!(String::from_utf8(message.clone()).unwrap(), "c=biws,r=<<>><<>>,p=621h6u6V3axb7mNYHNgTspTZ3SqILcxuJOsFu5wMjV8="); let message = server.process_message(&message, &env).unwrap().unwrap(); - eprintln!("server: {:?}", String::from_utf8(message.clone()).unwrap()); + assert_eq!( + String::from_utf8(message.clone()).unwrap(), + "v=moj4kNnZKB3wjXZeQsKYI9luTTakwgH8r0NdGOjugRY=" + ); assert!(client.process_message(&message, &env).unwrap().is_none()); assert!(client.success()); assert!(server.success());