Skip to content

Commit f94ce97

Browse files
authored
Handle and track startup parameters (#478)
* User server parameters struct instead of server info bytesmut * Refactor to use hashmap for all params and add server parameters to client * Sync parameters on client server checkout * minor refactor * update client side parameters when changed * Move the SET statement logic from the C packet to the S packet. * trigger build * revert validation changes * remove comment * Try fix * Reset cleanup state after sync * fix server version test * Track application name through client life for stats * Add tests * minor refactoring * fmt * fix * fmt
1 parent 9ab1285 commit f94ce97

File tree

8 files changed

+308
-123
lines changed

8 files changed

+308
-123
lines changed

src/admin.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::pool::BanReason;
2+
use crate::server::ServerParameters;
23
use crate::stats::pool::PoolStats;
34
use bytes::{Buf, BufMut, BytesMut};
45
use log::{error, info, trace};
@@ -17,16 +18,16 @@ use crate::pool::ClientServerMap;
1718
use crate::pool::{get_all_pools, get_pool};
1819
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};
1920

20-
pub fn generate_server_info_for_admin() -> BytesMut {
21-
let mut server_info = BytesMut::new();
21+
pub fn generate_server_parameters_for_admin() -> ServerParameters {
22+
let mut server_parameters = ServerParameters::new();
2223

23-
server_info.put(server_parameter_message("application_name", ""));
24-
server_info.put(server_parameter_message("client_encoding", "UTF8"));
25-
server_info.put(server_parameter_message("server_encoding", "UTF8"));
26-
server_info.put(server_parameter_message("server_version", VERSION));
27-
server_info.put(server_parameter_message("DateStyle", "ISO, MDY"));
24+
server_parameters.set_param("application_name".to_string(), "".to_string(), true);
25+
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), true);
26+
server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), true);
27+
server_parameters.set_param("server_version".to_string(), VERSION.to_string(), true);
28+
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), true);
2829

29-
server_info
30+
server_parameters
3031
}
3132

3233
/// Handle admin client.

src/client.rs

Lines changed: 54 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use tokio::net::TcpStream;
1212
use tokio::sync::broadcast::Receiver;
1313
use tokio::sync::mpsc::Sender;
1414

15-
use crate::admin::{generate_server_info_for_admin, handle_admin};
15+
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
1616
use crate::auth_passthrough::refetch_auth_hash;
1717
use crate::config::{
1818
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
@@ -22,7 +22,7 @@ use crate::messages::*;
2222
use crate::plugins::PluginOutput;
2323
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
2424
use crate::query_router::{Command, QueryRouter};
25-
use crate::server::Server;
25+
use crate::server::{Server, ServerParameters};
2626
use crate::stats::{ClientStats, ServerStats};
2727
use crate::tls::Tls;
2828

@@ -96,8 +96,8 @@ pub struct Client<S, T> {
9696
/// Postgres user for this client (This comes from the user in the connection string)
9797
username: String,
9898

99-
/// Application name for this client (defaults to pgcat)
100-
application_name: String,
99+
/// Server startup and session parameters that we're going to track
100+
server_parameters: ServerParameters,
101101

102102
/// Used to notify clients about an impending shutdown
103103
shutdown: Receiver<()>,
@@ -502,7 +502,7 @@ where
502502
};
503503

504504
// Authenticate admin user.
505-
let (transaction_mode, server_info) = if admin {
505+
let (transaction_mode, mut server_parameters) = if admin {
506506
let config = get_config();
507507

508508
// Compare server and client hashes.
@@ -521,7 +521,7 @@ where
521521
return Err(error);
522522
}
523523

524-
(false, generate_server_info_for_admin())
524+
(false, generate_server_parameters_for_admin())
525525
}
526526
// Authenticate normal user.
527527
else {
@@ -654,13 +654,16 @@ where
654654
}
655655
}
656656

657-
(transaction_mode, pool.server_info())
657+
(transaction_mode, pool.server_parameters())
658658
};
659659

660+
// Update the parameters to merge what the application sent and what's originally on the server
661+
server_parameters.set_from_hashmap(&parameters, false);
662+
660663
debug!("Password authentication successful");
661664

662665
auth_ok(&mut write).await?;
663-
write_all(&mut write, server_info).await?;
666+
write_all(&mut write, (&server_parameters).into()).await?;
664667
backend_key_data(&mut write, process_id, secret_key).await?;
665668
ready_for_query(&mut write).await?;
666669

@@ -690,7 +693,7 @@ where
690693
last_server_stats: None,
691694
pool_name: pool_name.clone(),
692695
username: username.clone(),
693-
application_name: application_name.to_string(),
696+
server_parameters,
694697
shutdown,
695698
connected_to_server: false,
696699
prepared_statements: HashMap::new(),
@@ -725,7 +728,7 @@ where
725728
last_server_stats: None,
726729
pool_name: String::from("undefined"),
727730
username: String::from("undefined"),
728-
application_name: String::from("undefined"),
731+
server_parameters: ServerParameters::new(),
729732
shutdown,
730733
connected_to_server: false,
731734
prepared_statements: HashMap::new(),
@@ -774,8 +777,11 @@ where
774777
let mut prepared_statement = None;
775778
let mut will_prepare = false;
776779

777-
let client_identifier =
778-
ClientIdentifier::new(&self.application_name, &self.username, &self.pool_name);
780+
let client_identifier = ClientIdentifier::new(
781+
&self.server_parameters.get_application_name(),
782+
&self.username,
783+
&self.pool_name,
784+
);
779785

780786
// Our custom protocol loop.
781787
// We expect the client to either start a transaction with regular queries
@@ -1115,10 +1121,7 @@ where
11151121
server.address()
11161122
);
11171123

1118-
// TODO: investigate other parameters and set them too.
1119-
1120-
// Set application_name.
1121-
server.set_name(&self.application_name).await?;
1124+
server.sync_parameters(&self.server_parameters).await?;
11221125

11231126
let mut initial_message = Some(message);
11241127

@@ -1296,7 +1299,9 @@ where
12961299
if !server.in_transaction() {
12971300
// Report transaction executed statistics.
12981301
self.stats.transaction();
1299-
server.stats().transaction(&self.application_name);
1302+
server
1303+
.stats()
1304+
.transaction(&self.server_parameters.get_application_name());
13001305

13011306
// Release server back to the pool if we are in transaction mode.
13021307
// If we are in session mode, we keep the server until the client disconnects.
@@ -1446,7 +1451,9 @@ where
14461451

14471452
if !server.in_transaction() {
14481453
self.stats.transaction();
1449-
server.stats().transaction(&self.application_name);
1454+
server
1455+
.stats()
1456+
.transaction(&self.server_parameters.get_application_name());
14501457

14511458
// Release server back to the pool if we are in transaction mode.
14521459
// If we are in session mode, we keep the server until the client disconnects.
@@ -1495,7 +1502,9 @@ where
14951502

14961503
if !server.in_transaction() {
14971504
self.stats.transaction();
1498-
server.stats().transaction(&self.application_name);
1505+
server
1506+
.stats()
1507+
.transaction(self.server_parameters.get_application_name());
14991508

15001509
// Release server back to the pool if we are in transaction mode.
15011510
// If we are in session mode, we keep the server until the client disconnects.
@@ -1547,7 +1556,9 @@ where
15471556

15481557
Err(Error::ClientError(format!(
15491558
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
1550-
self.pool_name, self.username, self.application_name
1559+
self.pool_name,
1560+
self.username,
1561+
self.server_parameters.get_application_name()
15511562
)))
15521563
}
15531564
}
@@ -1704,7 +1715,7 @@ where
17041715
client_stats.query();
17051716
server.stats().query(
17061717
Instant::now().duration_since(query_start).as_millis() as u64,
1707-
&self.application_name,
1718+
&self.server_parameters.get_application_name(),
17081719
);
17091720

17101721
Ok(())
@@ -1733,38 +1744,18 @@ where
17331744
pool: &ConnectionPool,
17341745
client_stats: &ClientStats,
17351746
) -> Result<BytesMut, Error> {
1736-
if pool.settings.user.statement_timeout > 0 {
1737-
match tokio::time::timeout(
1738-
tokio::time::Duration::from_millis(pool.settings.user.statement_timeout),
1739-
server.recv(),
1740-
)
1741-
.await
1742-
{
1743-
Ok(result) => match result {
1744-
Ok(message) => Ok(message),
1745-
Err(err) => {
1746-
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
1747-
error_response_terminal(
1748-
&mut self.write,
1749-
&format!("error receiving data from server: {:?}", err),
1750-
)
1751-
.await?;
1752-
Err(err)
1753-
}
1754-
},
1755-
Err(_) => {
1756-
error!(
1757-
"Statement timeout while talking to {:?} with user {}",
1758-
address, pool.settings.user.username
1759-
);
1760-
server.mark_bad();
1761-
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
1762-
error_response_terminal(&mut self.write, "pool statement timeout").await?;
1763-
Err(Error::StatementTimeout)
1764-
}
1765-
}
1766-
} else {
1767-
match server.recv().await {
1747+
let statement_timeout_duration = match pool.settings.user.statement_timeout {
1748+
0 => tokio::time::Duration::MAX,
1749+
timeout => tokio::time::Duration::from_millis(timeout),
1750+
};
1751+
1752+
match tokio::time::timeout(
1753+
statement_timeout_duration,
1754+
server.recv(Some(&mut self.server_parameters)),
1755+
)
1756+
.await
1757+
{
1758+
Ok(result) => match result {
17681759
Ok(message) => Ok(message),
17691760
Err(err) => {
17701761
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
@@ -1775,6 +1766,16 @@ where
17751766
.await?;
17761767
Err(err)
17771768
}
1769+
},
1770+
Err(_) => {
1771+
error!(
1772+
"Statement timeout while talking to {:?} with user {}",
1773+
address, pool.settings.user.username
1774+
);
1775+
server.mark_bad();
1776+
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
1777+
error_response_terminal(&mut self.write, "pool statement timeout").await?;
1778+
Err(Error::StatementTimeout)
17781779
}
17791780
}
17801781
}

src/messages.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ where
144144
bytes.put_slice(user.as_bytes());
145145
bytes.put_u8(0);
146146

147+
// Application name
148+
bytes.put(&b"application_name\0"[..]);
149+
bytes.put_slice(&b"pgcat\0"[..]);
150+
147151
// Database
148152
bytes.put(&b"database\0"[..]);
149153
bytes.put_slice(database.as_bytes());
@@ -731,6 +735,21 @@ impl BytesMutReader for Cursor<&BytesMut> {
731735
}
732736
}
733737

738+
impl BytesMutReader for BytesMut {
739+
/// Should only be used when reading strings from the message protocol.
740+
/// Can be used to read multiple strings from the same message which are separated by the null byte
741+
fn read_string(&mut self) -> Result<String, Error> {
742+
let null_index = self.iter().position(|&byte| byte == b'\0');
743+
744+
match null_index {
745+
Some(index) => {
746+
let string_bytes = self.split_to(index + 1);
747+
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
748+
}
749+
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
750+
}
751+
}
752+
}
734753
/// Parse (F) message.
735754
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
736755
#[derive(Clone, Debug)]

src/mirrors.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ impl MirroredClient {
7878
}
7979

8080
// Incoming data from server (we read to clear the socket buffer and discard the data)
81-
recv_result = server.recv() => {
81+
recv_result = server.recv(None) => {
8282
match recv_result {
8383
Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()),
8484
Err(err) => {

src/pool.rs

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use arc_swap::ArcSwap;
22
use async_trait::async_trait;
33
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
4-
use bytes::{BufMut, BytesMut};
54
use chrono::naive::NaiveDateTime;
65
use log::{debug, error, info, warn};
76
use once_cell::sync::Lazy;
@@ -25,7 +24,7 @@ use crate::errors::Error;
2524

2625
use crate::auth_passthrough::AuthPassthrough;
2726
use crate::plugins::prewarmer;
28-
use crate::server::Server;
27+
use crate::server::{Server, ServerParameters};
2928
use crate::sharding::ShardingFunction;
3029
use crate::stats::{AddressStats, ClientStats, ServerStats};
3130

@@ -196,10 +195,10 @@ pub struct ConnectionPool {
196195
/// that should not be queried.
197196
banlist: BanList,
198197

199-
/// The server information (K messages) have to be passed to the
198+
/// The server information has to be passed to the
200199
/// clients on startup. We pre-connect to all shards and replicas
201-
/// on pool creation and save the K messages here.
202-
server_info: Arc<RwLock<BytesMut>>,
200+
/// on pool creation and save the startup parameters here.
201+
original_server_parameters: Arc<RwLock<ServerParameters>>,
203202

204203
/// Pool configuration.
205204
pub settings: PoolSettings,
@@ -445,7 +444,7 @@ impl ConnectionPool {
445444
addresses,
446445
banlist: Arc::new(RwLock::new(banlist)),
447446
config_hash: new_pool_hash_value,
448-
server_info: Arc::new(RwLock::new(BytesMut::new())),
447+
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
449448
auth_hash: pool_auth_hash,
450449
settings: PoolSettings {
451450
pool_mode: match user.pool_mode {
@@ -528,7 +527,7 @@ impl ConnectionPool {
528527
for server in 0..self.servers(shard) {
529528
let databases = self.databases.clone();
530529
let validated = Arc::clone(&validated);
531-
let pool_server_info = Arc::clone(&self.server_info);
530+
let pool_server_parameters = Arc::clone(&self.original_server_parameters);
532531

533532
let task = tokio::task::spawn(async move {
534533
let connection = match databases[shard][server].get().await {
@@ -541,11 +540,10 @@ impl ConnectionPool {
541540

542541
let proxy = connection;
543542
let server = &*proxy;
544-
let server_info = server.server_info();
543+
let server_parameters: ServerParameters = server.server_parameters();
545544

546-
let mut guard = pool_server_info.write();
547-
guard.clear();
548-
guard.put(server_info.clone());
545+
let mut guard = pool_server_parameters.write();
546+
*guard = server_parameters;
549547
validated.store(true, Ordering::Relaxed);
550548
});
551549

@@ -557,7 +555,7 @@ impl ConnectionPool {
557555

558556
// TODO: compare server information to make sure
559557
// all shards are running identical configurations.
560-
if self.server_info.read().is_empty() {
558+
if !self.validated() {
561559
error!("Could not validate connection pool");
562560
return Err(Error::AllServersDown);
563561
}
@@ -917,8 +915,8 @@ impl ConnectionPool {
917915
&self.addresses[shard][server]
918916
}
919917

920-
pub fn server_info(&self) -> BytesMut {
921-
self.server_info.read().clone()
918+
pub fn server_parameters(&self) -> ServerParameters {
919+
self.original_server_parameters.read().clone()
922920
}
923921

924922
fn busy_connection_count(&self, address: &Address) -> u32 {

0 commit comments

Comments
 (0)