diff --git a/src/auth_passthrough.rs b/src/auth_passthrough.rs index 159847ed..53ef93d4 100644 --- a/src/auth_passthrough.rs +++ b/src/auth_passthrough.rs @@ -1,3 +1,4 @@ +use crate::config::AuthType; use crate::errors::Error; use crate::pool::ConnectionPool; use crate::server::Server; @@ -71,6 +72,7 @@ impl AuthPassthrough { pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result { let auth_user = crate::config::User { username: self.user.clone(), + auth_type: AuthType::MD5, password: Some(self.password.clone()), server_username: None, server_password: None, diff --git a/src/client.rs b/src/client.rs index 23392b73..405d72be 100644 --- a/src/client.rs +++ b/src/client.rs @@ -14,7 +14,9 @@ use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_parameters_for_admin, handle_admin}; use crate::auth_passthrough::refetch_auth_hash; -use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode}; +use crate::config::{ + get_config, get_idle_client_in_transaction_timeout, Address, AuthType, PoolMode, +}; use crate::constants::*; use crate::messages::*; use crate::plugins::PluginOutput; @@ -463,8 +465,8 @@ where .count() == 1; - // Kick any client that's not admin while we're in admin-only mode. if !admin && admin_only { + // Kick any client that's not admin while we're in admin-only mode. debug!( "Rejecting non-admin connection to {} when in admin only mode", pool_name @@ -481,72 +483,76 @@ where let process_id: i32 = rand::random(); let secret_key: i32 = rand::random(); - // Perform MD5 authentication. - // TODO: Add SASL support. - let salt = md5_challenge(&mut write).await?; - - let code = match read.read_u8().await { - Ok(p) => p, - Err(_) => { - return Err(Error::ClientSocketError( - "password code".into(), - client_identifier, - )) - } - }; - - // PasswordMessage - if code as char != 'p' { - return Err(Error::ProtocolSyncError(format!( - "Expected p, got {}", - code as char - ))); - } - - let len = match read.read_i32().await { - Ok(len) => len, - Err(_) => { - return Err(Error::ClientSocketError( - "password message length".into(), - client_identifier, - )) - } - }; - - let mut password_response = vec![0u8; (len - 4) as usize]; - - match read.read_exact(&mut password_response).await { - Ok(_) => (), - Err(_) => { - return Err(Error::ClientSocketError( - "password message".into(), - client_identifier, - )) - } - }; - let mut prepared_statements_enabled = false; // Authenticate admin user. let (transaction_mode, mut server_parameters) = if admin { let config = get_config(); + // TODO: Add SASL support. + // Perform MD5 authentication. + match config.general.admin_auth_type { + AuthType::Trust => (), + AuthType::MD5 => { + let salt = md5_challenge(&mut write).await?; + + let code = match read.read_u8().await { + Ok(p) => p, + Err(_) => { + return Err(Error::ClientSocketError( + "password code".into(), + client_identifier, + )) + } + }; + + // PasswordMessage + if code as char != 'p' { + return Err(Error::ProtocolSyncError(format!( + "Expected p, got {}", + code as char + ))); + } - // Compare server and client hashes. - let password_hash = md5_hash_password( - &config.general.admin_username, - &config.general.admin_password, - &salt, - ); + let len = match read.read_i32().await { + Ok(len) => len, + Err(_) => { + return Err(Error::ClientSocketError( + "password message length".into(), + client_identifier, + )) + } + }; - if password_hash != password_response { - let error = Error::ClientGeneralError("Invalid password".into(), client_identifier); + let mut password_response = vec![0u8; (len - 4) as usize]; - warn!("{}", error); - wrong_password(&mut write, username).await?; + match read.read_exact(&mut password_response).await { + Ok(_) => (), + Err(_) => { + return Err(Error::ClientSocketError( + "password message".into(), + client_identifier, + )) + } + }; - return Err(error); - } + // Compare server and client hashes. + let password_hash = md5_hash_password( + &config.general.admin_username, + &config.general.admin_password, + &salt, + ); + + if password_hash != password_response { + let error = + Error::ClientGeneralError("Invalid password".into(), client_identifier); + warn!("{}", error); + wrong_password(&mut write, username).await?; + + return Err(error); + } + } + } (false, generate_server_parameters_for_admin()) } // Authenticate normal user. @@ -573,92 +579,143 @@ where // Obtain the hash to compare, we give preference to that written in cleartext in config // if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained // when the pool was created. If there is no hash there, we try to fetch it one more time. - let password_hash = if let Some(password) = &pool.settings.user.password { - Some(md5_hash_password(username, password, &salt)) - } else { - if !get_config().is_auth_query_configured() { - wrong_password(&mut write, username).await?; - return Err(Error::ClientAuthImpossible(username.into())); - } - - let mut hash = (*pool.auth_hash.read()).clone(); - - if hash.is_none() { - warn!( - "Query auth configured \ - but no hash password found \ - for pool {}. Will try to refetch it.", - pool_name - ); + match pool.settings.user.auth_type { + AuthType::Trust => (), + AuthType::MD5 => { + // Perform MD5 authentication. + // TODO: Add SASL support. + let salt = md5_challenge(&mut write).await?; + + let code = match read.read_u8().await { + Ok(p) => p, + Err(_) => { + return Err(Error::ClientSocketError( + "password code".into(), + client_identifier, + )) + } + }; + + // PasswordMessage + if code as char != 'p' { + return Err(Error::ProtocolSyncError(format!( + "Expected p, got {}", + code as char + ))); + } - match refetch_auth_hash(&pool).await { - Ok(fetched_hash) => { - warn!("Password for {}, obtained. Updating.", client_identifier); + let len = match read.read_i32().await { + Ok(len) => len, + Err(_) => { + return Err(Error::ClientSocketError( + "password message length".into(), + client_identifier, + )) + } + }; - { - let mut pool_auth_hash = pool.auth_hash.write(); - *pool_auth_hash = Some(fetched_hash.clone()); - } + let mut password_response = vec![0u8; (len - 4) as usize]; - hash = Some(fetched_hash); + match read.read_exact(&mut password_response).await { + Ok(_) => (), + Err(_) => { + return Err(Error::ClientSocketError( + "password message".into(), + client_identifier, + )) } + }; - Err(err) => { + let password_hash = if let Some(password) = &pool.settings.user.password { + Some(md5_hash_password(username, password, &salt)) + } else { + if !get_config().is_auth_query_configured() { wrong_password(&mut write, username).await?; - - return Err(Error::ClientAuthPassthroughError( - err.to_string(), - client_identifier, - )); + return Err(Error::ClientAuthImpossible(username.into())); } - } - }; - Some(md5_hash_second_pass(&hash.unwrap(), &salt)) - }; + let mut hash = (*pool.auth_hash.read()).clone(); - // Once we have the resulting hash, we compare with what the client gave us. - // If they do not match and auth query is set up, we try to refetch the hash one more time - // to see if the password has changed since the pool was created. - // - // @TODO: we could end up fetching again the same password twice (see above). - if password_hash.unwrap() != password_response { - warn!( - "Invalid password {}, will try to refetch it.", - client_identifier - ); + if hash.is_none() { + warn!( + "Query auth configured \ + but no hash password found \ + for pool {}. Will try to refetch it.", + pool_name + ); - let fetched_hash = match refetch_auth_hash(&pool).await { - Ok(fetched_hash) => fetched_hash, - Err(err) => { - wrong_password(&mut write, username).await?; + match refetch_auth_hash(&pool).await { + Ok(fetched_hash) => { + warn!( + "Password for {}, obtained. Updating.", + client_identifier + ); - return Err(err); - } - }; + { + let mut pool_auth_hash = pool.auth_hash.write(); + *pool_auth_hash = Some(fetched_hash.clone()); + } - let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); + hash = Some(fetched_hash); + } - // Ok password changed in server an auth is possible. - if new_password_hash == password_response { - warn!( - "Password for {}, changed in server. Updating.", - client_identifier - ); + Err(err) => { + wrong_password(&mut write, username).await?; - { - let mut pool_auth_hash = pool.auth_hash.write(); - *pool_auth_hash = Some(fetched_hash); + return Err(Error::ClientAuthPassthroughError( + err.to_string(), + client_identifier, + )); + } + } + }; + + Some(md5_hash_second_pass(&hash.unwrap(), &salt)) + }; + + // Once we have the resulting hash, we compare with what the client gave us. + // If they do not match and auth query is set up, we try to refetch the hash one more time + // to see if the password has changed since the pool was created. + // + // @TODO: we could end up fetching again the same password twice (see above). + if password_hash.unwrap() != password_response { + warn!( + "Invalid password {}, will try to refetch it.", + client_identifier + ); + + let fetched_hash = match refetch_auth_hash(&pool).await { + Ok(fetched_hash) => fetched_hash, + Err(err) => { + wrong_password(&mut write, username).await?; + + return Err(err); + } + }; + + let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); + + // Ok password changed in server an auth is possible. + if new_password_hash == password_response { + warn!( + "Password for {}, changed in server. Updating.", + client_identifier + ); + + { + let mut pool_auth_hash = pool.auth_hash.write(); + *pool_auth_hash = Some(fetched_hash); + } + } else { + wrong_password(&mut write, username).await?; + return Err(Error::ClientGeneralError( + "Invalid password".into(), + client_identifier, + )); + } } - } else { - wrong_password(&mut write, username).await?; - return Err(Error::ClientGeneralError( - "Invalid password".into(), - client_identifier, - )); } } - let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; prepared_statements_enabled = transaction_mode && pool.prepared_statement_cache.is_some(); diff --git a/src/config.rs b/src/config.rs index c7aaf4c3..b0d98fb5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -208,6 +208,9 @@ impl Address { pub struct User { pub username: String, pub password: Option, + + #[serde(default = "User::default_auth_type")] + pub auth_type: AuthType, pub server_username: Option, pub server_password: Option, pub pool_size: u32, @@ -225,6 +228,7 @@ impl Default for User { User { username: String::from("postgres"), password: None, + auth_type: AuthType::MD5, server_username: None, server_password: None, pool_size: 15, @@ -239,6 +243,10 @@ impl Default for User { } impl User { + pub fn default_auth_type() -> AuthType { + AuthType::MD5 + } + fn validate(&self) -> Result<(), Error> { if let Some(min_pool_size) = self.min_pool_size { if min_pool_size > self.pool_size { @@ -334,6 +342,9 @@ pub struct General { pub admin_username: String, pub admin_password: String, + #[serde(default = "General::default_admin_auth_type")] + pub admin_auth_type: AuthType, + #[serde(default = "General::default_validate_config")] pub validate_config: bool, @@ -348,6 +359,10 @@ impl General { "0.0.0.0".into() } + pub fn default_admin_auth_type() -> AuthType { + AuthType::MD5 + } + pub fn default_port() -> u16 { 5432 } @@ -456,6 +471,7 @@ impl Default for General { verify_server_certificate: false, admin_username: String::from("admin"), admin_password: String::from("admin"), + admin_auth_type: AuthType::MD5, validate_config: true, auth_query: None, auth_query_user: None, @@ -476,6 +492,15 @@ pub enum PoolMode { Session, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy, Hash)] +pub enum AuthType { + #[serde(alias = "trust", alias = "Trust")] + Trust, + + #[serde(alias = "md5", alias = "MD5")] + MD5, +} + impl std::fmt::Display for PoolMode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/tests/python/conftest.py b/tests/python/conftest.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/python/test_auth.py b/tests/python/test_auth.py new file mode 100644 index 00000000..bd943429 --- /dev/null +++ b/tests/python/test_auth.py @@ -0,0 +1,71 @@ +import utils +import signal + +class TestTrustAuth: + @classmethod + def setup_method(cls): + config= """ + [general] + host = "0.0.0.0" + port = 6432 + admin_username = "admin_user" + admin_password = "" + admin_auth_type = "trust" + + [pools.sharded_db.users.0] + username = "sharding_user" + password = "sharding_user" + auth_type = "trust" + pool_size = 10 + min_pool_size = 1 + pool_mode = "transaction" + + [pools.sharded_db.shards.0] + servers = [ + [ "127.0.0.1", 5432, "primary" ], + ] + database = "shard0" + """ + utils.pgcat_generic_start(config) + + @classmethod + def teardown_method(self): + utils.pg_cat_send_signal(signal.SIGTERM) + + def test_admin_trust_auth(self): + conn, cur = utils.connect_db_trust(admin=True) + cur.execute("SHOW POOLS") + res = cur.fetchall() + print(res) + utils.cleanup_conn(conn, cur) + + def test_normal_trust_auth(self): + conn, cur = utils.connect_db_trust(autocommit=False) + cur.execute("SELECT 1") + res = cur.fetchall() + print(res) + utils.cleanup_conn(conn, cur) + +class TestMD5Auth: + @classmethod + def setup_method(cls): + utils.pgcat_start() + + @classmethod + def teardown_method(self): + utils.pg_cat_send_signal(signal.SIGTERM) + + def test_normal_db_access(self): + conn, cur = utils.connect_db(autocommit=False) + cur.execute("SELECT 1") + res = cur.fetchall() + print(res) + utils.cleanup_conn(conn, cur) + + def test_admin_db_access(self): + conn, cur = utils.connect_db(admin=True) + + cur.execute("SHOW POOLS") + res = cur.fetchall() + print(res) + utils.cleanup_conn(conn, cur) diff --git a/tests/python/test_pgcat.py b/tests/python/test_pgcat.py index dc2f11e5..773715d4 100644 --- a/tests/python/test_pgcat.py +++ b/tests/python/test_pgcat.py @@ -1,30 +1,12 @@ -import os + import signal import time import psycopg2 - import utils SHUTDOWN_TIMEOUT = 5 -def test_normal_db_access(): - utils.pgcat_start() - conn, cur = utils.connect_db(autocommit=False) - cur.execute("SELECT 1") - res = cur.fetchall() - print(res) - utils.cleanup_conn(conn, cur) - - -def test_admin_db_access(): - conn, cur = utils.connect_db(admin=True) - - cur.execute("SHOW POOLS") - res = cur.fetchall() - print(res) - utils.cleanup_conn(conn, cur) - def test_shutdown_logic(): @@ -256,3 +238,5 @@ def test_shutdown_logic(): utils.cleanup_conn(conn, cur) utils.pg_cat_send_signal(signal.SIGTERM) + + # - - - - - - - - - - - - - - - - - - diff --git a/tests/python/utils.py b/tests/python/utils.py index 5c49bce9..9a1c6de9 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -1,20 +1,49 @@ -from typing import Tuple import os -import psutil import signal import time +from typing import Tuple +import tempfile +import psutil import psycopg2 PGCAT_HOST = "127.0.0.1" PGCAT_PORT = "6432" -def pgcat_start(): + +def _pgcat_start(config_path: str): pg_cat_send_signal(signal.SIGTERM) - os.system("./target/debug/pgcat .circleci/pgcat.toml &") + os.system(f"./target/debug/pgcat {config_path} &") time.sleep(2) +def pgcat_start(): + _pgcat_start(config_path='.circleci/pgcat.toml') + + +def pgcat_generic_start(config: str): + tmp = tempfile.NamedTemporaryFile() + with open(tmp.name, 'w') as f: + f.write(config) + _pgcat_start(config_path=tmp.name) + + +def glauth_send_signal(signal: signal.Signals): + try: + for proc in psutil.process_iter(["pid", "name"]): + if proc.name() == "glauth": + os.kill(proc.pid, signal) + except Exception as e: + # The process can be gone when we send this signal + print(e) + + if signal == signal.SIGTERM: + # Returns 0 if pgcat process exists + time.sleep(2) + if not os.system('pgrep glauth'): + raise Exception("glauth not closed after SIGTERM") + + def pg_cat_send_signal(signal: signal.Signals): try: for proc in psutil.process_iter(["pid", "name"]): @@ -54,6 +83,27 @@ def connect_db( return (conn, cur) +def connect_db_trust( + autocommit: bool = True, + admin: bool = False, +) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]: + + if admin: + user = "admin_user" + db = "pgcat" + else: + user = "sharding_user" + db = "sharded_db" + + conn = psycopg2.connect( + f"postgres://{user}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat", + connect_timeout=2, + ) + conn.autocommit = autocommit + cur = conn.cursor() + + return (conn, cur) + def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor): cur.close()