@@ -12,7 +12,7 @@ use tokio::net::TcpStream;
12
12
use tokio:: sync:: broadcast:: Receiver ;
13
13
use tokio:: sync:: mpsc:: Sender ;
14
14
15
- use crate :: admin:: { generate_server_info_for_admin , handle_admin} ;
15
+ use crate :: admin:: { generate_server_parameters_for_admin , handle_admin} ;
16
16
use crate :: auth_passthrough:: refetch_auth_hash;
17
17
use crate :: config:: {
18
18
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address , PoolMode ,
@@ -22,7 +22,7 @@ use crate::messages::*;
22
22
use crate :: plugins:: PluginOutput ;
23
23
use crate :: pool:: { get_pool, ClientServerMap , ConnectionPool } ;
24
24
use crate :: query_router:: { Command , QueryRouter } ;
25
- use crate :: server:: Server ;
25
+ use crate :: server:: { Server , ServerParameters } ;
26
26
use crate :: stats:: { ClientStats , ServerStats } ;
27
27
use crate :: tls:: Tls ;
28
28
@@ -96,8 +96,8 @@ pub struct Client<S, T> {
96
96
/// Postgres user for this client (This comes from the user in the connection string)
97
97
username : String ,
98
98
99
- /// Application name for this client (defaults to pgcat)
100
- application_name : String ,
99
+ /// Server startup and session parameters that we're going to track
100
+ server_parameters : ServerParameters ,
101
101
102
102
/// Used to notify clients about an impending shutdown
103
103
shutdown : Receiver < ( ) > ,
@@ -502,7 +502,7 @@ where
502
502
} ;
503
503
504
504
// Authenticate admin user.
505
- let ( transaction_mode, server_info ) = if admin {
505
+ let ( transaction_mode, mut server_parameters ) = if admin {
506
506
let config = get_config ( ) ;
507
507
508
508
// Compare server and client hashes.
@@ -521,7 +521,7 @@ where
521
521
return Err ( error) ;
522
522
}
523
523
524
- ( false , generate_server_info_for_admin ( ) )
524
+ ( false , generate_server_parameters_for_admin ( ) )
525
525
}
526
526
// Authenticate normal user.
527
527
else {
@@ -654,13 +654,16 @@ where
654
654
}
655
655
}
656
656
657
- ( transaction_mode, pool. server_info ( ) )
657
+ ( transaction_mode, pool. server_parameters ( ) )
658
658
} ;
659
659
660
+ // Update the parameters to merge what the application sent and what's originally on the server
661
+ server_parameters. set_from_hashmap ( & parameters, false ) ;
662
+
660
663
debug ! ( "Password authentication successful" ) ;
661
664
662
665
auth_ok ( & mut write) . await ?;
663
- write_all ( & mut write, server_info ) . await ?;
666
+ write_all ( & mut write, ( & server_parameters ) . into ( ) ) . await ?;
664
667
backend_key_data ( & mut write, process_id, secret_key) . await ?;
665
668
ready_for_query ( & mut write) . await ?;
666
669
@@ -690,7 +693,7 @@ where
690
693
last_server_stats : None ,
691
694
pool_name : pool_name. clone ( ) ,
692
695
username : username. clone ( ) ,
693
- application_name : application_name . to_string ( ) ,
696
+ server_parameters ,
694
697
shutdown,
695
698
connected_to_server : false ,
696
699
prepared_statements : HashMap :: new ( ) ,
@@ -725,7 +728,7 @@ where
725
728
last_server_stats : None ,
726
729
pool_name : String :: from ( "undefined" ) ,
727
730
username : String :: from ( "undefined" ) ,
728
- application_name : String :: from ( "undefined" ) ,
731
+ server_parameters : ServerParameters :: new ( ) ,
729
732
shutdown,
730
733
connected_to_server : false ,
731
734
prepared_statements : HashMap :: new ( ) ,
@@ -774,8 +777,11 @@ where
774
777
let mut prepared_statement = None ;
775
778
let mut will_prepare = false ;
776
779
777
- let client_identifier =
778
- ClientIdentifier :: new ( & self . application_name , & self . username , & self . pool_name ) ;
780
+ let client_identifier = ClientIdentifier :: new (
781
+ & self . server_parameters . get_application_name ( ) ,
782
+ & self . username ,
783
+ & self . pool_name ,
784
+ ) ;
779
785
780
786
// Our custom protocol loop.
781
787
// We expect the client to either start a transaction with regular queries
@@ -1115,10 +1121,7 @@ where
1115
1121
server. address( )
1116
1122
) ;
1117
1123
1118
- // TODO: investigate other parameters and set them too.
1119
-
1120
- // Set application_name.
1121
- server. set_name ( & self . application_name ) . await ?;
1124
+ server. sync_parameters ( & self . server_parameters ) . await ?;
1122
1125
1123
1126
let mut initial_message = Some ( message) ;
1124
1127
@@ -1296,7 +1299,9 @@ where
1296
1299
if !server. in_transaction ( ) {
1297
1300
// Report transaction executed statistics.
1298
1301
self . stats . transaction ( ) ;
1299
- server. stats ( ) . transaction ( & self . application_name ) ;
1302
+ server
1303
+ . stats ( )
1304
+ . transaction ( & self . server_parameters . get_application_name ( ) ) ;
1300
1305
1301
1306
// Release server back to the pool if we are in transaction mode.
1302
1307
// If we are in session mode, we keep the server until the client disconnects.
@@ -1446,7 +1451,9 @@ where
1446
1451
1447
1452
if !server. in_transaction ( ) {
1448
1453
self . stats . transaction ( ) ;
1449
- server. stats ( ) . transaction ( & self . application_name ) ;
1454
+ server
1455
+ . stats ( )
1456
+ . transaction ( & self . server_parameters . get_application_name ( ) ) ;
1450
1457
1451
1458
// Release server back to the pool if we are in transaction mode.
1452
1459
// If we are in session mode, we keep the server until the client disconnects.
@@ -1495,7 +1502,9 @@ where
1495
1502
1496
1503
if !server. in_transaction ( ) {
1497
1504
self . stats . transaction ( ) ;
1498
- server. stats ( ) . transaction ( & self . application_name ) ;
1505
+ server
1506
+ . stats ( )
1507
+ . transaction ( self . server_parameters . get_application_name ( ) ) ;
1499
1508
1500
1509
// Release server back to the pool if we are in transaction mode.
1501
1510
// If we are in session mode, we keep the server until the client disconnects.
@@ -1547,7 +1556,9 @@ where
1547
1556
1548
1557
Err ( Error :: ClientError ( format ! (
1549
1558
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}" ,
1550
- self . pool_name, self . username, self . application_name
1559
+ self . pool_name,
1560
+ self . username,
1561
+ self . server_parameters. get_application_name( )
1551
1562
) ) )
1552
1563
}
1553
1564
}
@@ -1704,7 +1715,7 @@ where
1704
1715
client_stats. query ( ) ;
1705
1716
server. stats ( ) . query (
1706
1717
Instant :: now ( ) . duration_since ( query_start) . as_millis ( ) as u64 ,
1707
- & self . application_name ,
1718
+ & self . server_parameters . get_application_name ( ) ,
1708
1719
) ;
1709
1720
1710
1721
Ok ( ( ) )
@@ -1733,38 +1744,18 @@ where
1733
1744
pool : & ConnectionPool ,
1734
1745
client_stats : & ClientStats ,
1735
1746
) -> Result < BytesMut , Error > {
1736
- if pool. settings . user . statement_timeout > 0 {
1737
- match tokio:: time:: timeout (
1738
- tokio:: time:: Duration :: from_millis ( pool. settings . user . statement_timeout ) ,
1739
- server. recv ( ) ,
1740
- )
1741
- . await
1742
- {
1743
- Ok ( result) => match result {
1744
- Ok ( message) => Ok ( message) ,
1745
- Err ( err) => {
1746
- pool. ban ( address, BanReason :: MessageReceiveFailed , Some ( client_stats) ) ;
1747
- error_response_terminal (
1748
- & mut self . write ,
1749
- & format ! ( "error receiving data from server: {:?}" , err) ,
1750
- )
1751
- . await ?;
1752
- Err ( err)
1753
- }
1754
- } ,
1755
- Err ( _) => {
1756
- error ! (
1757
- "Statement timeout while talking to {:?} with user {}" ,
1758
- address, pool. settings. user. username
1759
- ) ;
1760
- server. mark_bad ( ) ;
1761
- pool. ban ( address, BanReason :: StatementTimeout , Some ( client_stats) ) ;
1762
- error_response_terminal ( & mut self . write , "pool statement timeout" ) . await ?;
1763
- Err ( Error :: StatementTimeout )
1764
- }
1765
- }
1766
- } else {
1767
- match server. recv ( ) . await {
1747
+ let statement_timeout_duration = match pool. settings . user . statement_timeout {
1748
+ 0 => tokio:: time:: Duration :: MAX ,
1749
+ timeout => tokio:: time:: Duration :: from_millis ( timeout) ,
1750
+ } ;
1751
+
1752
+ match tokio:: time:: timeout (
1753
+ statement_timeout_duration,
1754
+ server. recv ( Some ( & mut self . server_parameters ) ) ,
1755
+ )
1756
+ . await
1757
+ {
1758
+ Ok ( result) => match result {
1768
1759
Ok ( message) => Ok ( message) ,
1769
1760
Err ( err) => {
1770
1761
pool. ban ( address, BanReason :: MessageReceiveFailed , Some ( client_stats) ) ;
@@ -1775,6 +1766,16 @@ where
1775
1766
. await ?;
1776
1767
Err ( err)
1777
1768
}
1769
+ } ,
1770
+ Err ( _) => {
1771
+ error ! (
1772
+ "Statement timeout while talking to {:?} with user {}" ,
1773
+ address, pool. settings. user. username
1774
+ ) ;
1775
+ server. mark_bad ( ) ;
1776
+ pool. ban ( address, BanReason :: StatementTimeout , Some ( client_stats) ) ;
1777
+ error_response_terminal ( & mut self . write , "pool statement timeout" ) . await ?;
1778
+ Err ( Error :: StatementTimeout )
1778
1779
}
1779
1780
}
1780
1781
}
0 commit comments