diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index a258090b1582..df716f8455f0 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -776,6 +776,7 @@ impl From<&jose_jwk::Key> for KeyType { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use std::future::IntoFuture; use std::net::SocketAddr; diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index f38ecf715f7d..50cb94bfa069 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -463,6 +463,8 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { #[cfg(test)] mod tests { + #![allow(clippy::unimplemented, clippy::unwrap_used)] + use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index f6bce9f2d8aa..eff49a402aaa 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -250,6 +250,7 @@ fn project_name_valid(name: &str) -> bool { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use serde_json::json; use ComputeUserInfoParseError::*; diff --git a/proxy/src/cache/endpoints.rs b/proxy/src/cache/endpoints.rs index 20db1fbb147a..0136446d6dfb 100644 --- a/proxy/src/cache/endpoints.rs +++ b/proxy/src/cache/endpoints.rs @@ -12,6 +12,7 @@ use tracing::info; use crate::config::EndpointCacheConfig; use crate::context::RequestContext; +use crate::ext::LockExt; use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; use crate::rate_limiter::GlobalRateLimiter; @@ -96,7 +97,7 @@ impl EndpointsCache { // If the limiter allows, we can pretend like it's valid // (incase it is, due to redis channel lag). - if self.limiter.lock().unwrap().check() { + if self.limiter.lock_propagate_poison().check() { return true; } @@ -258,6 +259,7 @@ impl EndpointsCache { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use super::*; diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 84430dc812e9..cab0b8b90594 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -365,6 +365,7 @@ impl Cache for ProjectInfoCacheImpl { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use super::*; use crate::scram::ServerSecret; diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index ed717507ee40..dd3edd6abc8f 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -13,6 +13,7 @@ use uuid::Uuid; use crate::auth::{check_peer_addr_is_in_list, IpPattern}; use crate::error::ReportableError; +use crate::ext::LockExt; use crate::metrics::{CancellationRequest, CancellationSource, Metrics}; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::cancellation_publisher::{ @@ -114,7 +115,7 @@ impl CancellationHandler

{ IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), }; - if !self.limiter.lock().unwrap().check(subnet_key, 1) { + if !self.limiter.lock_propagate_poison().check(subnet_key, 1) { // log only the subnet part of the IP address to know which subnet is rate limited tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}"); Metrics::get() @@ -283,6 +284,7 @@ impl

Drop for Session

{ } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use super::*; diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 8bc8e3f96f59..1f991d595eca 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -221,15 +221,10 @@ impl CertResolver { ) -> anyhow::Result<()> { let priv_key = { let key_bytes = std::fs::read(key_path) - .context(format!("Failed to read TLS keys at '{key_path}'"))?; - let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec(); - - ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); - PrivateKeyDer::Pkcs8( - keys.pop() - .unwrap() - .context(format!("Failed to parse TLS keys at '{key_path}'"))?, - ) + .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?; + rustls_pemfile::private_key(&mut &key_bytes[..]) + .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? + .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? }; let cert_chain_bytes = std::fs::read(cert_path) diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index 3105d085260d..5f65b17374f2 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -23,6 +23,7 @@ use utils::backoff; use super::{RequestContextInner, LOG_CHAN}; use crate::config::remote_storage_from_toml; use crate::context::LOG_CHAN_DISCONNECT; +use crate::ext::TaskExt; #[derive(clap::Args, Clone, Debug)] pub struct ParquetUploadArgs { @@ -171,7 +172,9 @@ pub async fn worker( }; let (tx, mut rx) = mpsc::unbounded_channel(); - LOG_CHAN.set(tx.downgrade()).unwrap(); + LOG_CHAN + .set(tx.downgrade()) + .expect("only one worker should set the channel"); // setup row stream that will close on cancellation let cancellation_token2 = cancellation_token.clone(); @@ -207,7 +210,9 @@ pub async fn worker( config.parquet_upload_disconnect_events_remote_storage { let (tx_disconnect, mut rx_disconnect) = mpsc::unbounded_channel(); - LOG_CHAN_DISCONNECT.set(tx_disconnect.downgrade()).unwrap(); + LOG_CHAN_DISCONNECT + .set(tx_disconnect.downgrade()) + .expect("only one worker should set the channel"); // setup row stream that will close on cancellation tokio::spawn(async move { @@ -326,7 +331,7 @@ where Ok::<_, parquet::errors::ParquetError>((rows, w, rg_meta)) }) .await - .unwrap()?; + .propagate_task_panic()?; rows.clear(); Ok((rows, w, rg_meta)) @@ -352,7 +357,7 @@ async fn upload_parquet( Ok((buffer, metadata)) }) .await - .unwrap()?; + .propagate_task_panic()?; let data = buffer.split().freeze(); @@ -409,6 +414,7 @@ async fn upload_parquet( } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use std::net::Ipv4Addr; use std::num::NonZeroUsize; diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 93edd6547693..5f8bda0f35ae 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -102,7 +102,9 @@ impl MockControlPlane { Some(s) => { info!("got allowed_ips: {s}"); s.split(',') - .map(|s| IpPattern::from_str(s).unwrap()) + .map(|s| { + IpPattern::from_str(s).expect("mocked ip pattern should be correct") + }) .collect() } None => vec![], diff --git a/proxy/src/ext.rs b/proxy/src/ext.rs new file mode 100644 index 000000000000..8d00afbf51a4 --- /dev/null +++ b/proxy/src/ext.rs @@ -0,0 +1,41 @@ +use std::panic::resume_unwind; +use std::sync::{Mutex, MutexGuard}; + +use tokio::task::JoinError; + +pub(crate) trait LockExt { + fn lock_propagate_poison(&self) -> MutexGuard<'_, T>; +} + +impl LockExt for Mutex { + /// Lock the mutex and panic if the mutex was poisoned. + #[track_caller] + fn lock_propagate_poison(&self) -> MutexGuard<'_, T> { + match self.lock() { + Ok(guard) => guard, + // poison occurs when another thread panicked while holding the lock guard. + // since panicking is often unrecoverable, propagating the poison panic is reasonable. + Err(poison) => panic!("{poison}"), + } + } +} + +pub(crate) trait TaskExt { + fn propagate_task_panic(self) -> T; +} + +impl TaskExt for Result { + /// Unwrap the result and panic if the inner task panicked. + /// Also panics if the task was cancelled + #[track_caller] + fn propagate_task_panic(self) -> T { + match self { + Ok(t) => t, + // Using resume_unwind prevents the panic hook being called twice. + // Since we use this for structured concurrency, there is only + // 1 logical panic, so this is more correct. + Err(e) if e.is_panic() => resume_unwind(e.into_panic()), + Err(e) => panic!("unexpected task error: {e}"), + } + } +} diff --git a/proxy/src/http/health_server.rs b/proxy/src/http/health_server.rs index 978ad9f76131..6ca091feb716 100644 --- a/proxy/src/http/health_server.rs +++ b/proxy/src/http/health_server.rs @@ -14,6 +14,7 @@ use utils::http::error::ApiError; use utils::http::json::json_response; use utils::http::{RouterBuilder, RouterService}; +use crate::ext::{LockExt, TaskExt}; use crate::jemalloc; async fn status_handler(_: Request) -> Result, ApiError> { @@ -76,7 +77,7 @@ async fn prometheus_metrics_handler( let body = tokio::task::spawn_blocking(move || { let _span = span.entered(); - let mut state = state.lock().unwrap(); + let mut state = state.lock_propagate_poison(); let PrometheusHandler { encoder, metrics } = &mut *state; metrics @@ -94,13 +95,13 @@ async fn prometheus_metrics_handler( body }) .await - .unwrap(); + .propagate_task_panic(); let response = Response::builder() .status(200) .header(CONTENT_TYPE, "text/plain; version=0.0.4") .body(Body::from(body)) - .unwrap(); + .expect("response headers should be valid"); Ok(response) } diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index f56d92a6b31e..79c6020302af 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -83,7 +83,7 @@ impl StringInterner { pub(crate) fn new() -> Self { StringInterner { inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher( - Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()), + Capacity::new(2500, NonZeroUsize::new(1 << 16).expect("value is nonzero")), // unbounded MemoryLimits::for_memory_usage(usize::MAX), BuildHasherDefault::::default(), @@ -207,6 +207,7 @@ impl From for ProjectIdInt { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use std::sync::OnceLock; diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index ba69f9cf2d28..a5a72f26d950 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -22,8 +22,8 @@ clippy::string_add, clippy::string_to_string, clippy::todo, - // TODO: consider clippy::unimplemented - // TODO: consider clippy::unwrap_used + clippy::unimplemented, + clippy::unwrap_used, )] // List of permanently allowed lints. #![allow( @@ -82,6 +82,7 @@ pub mod console_redirect_proxy; pub mod context; pub mod control_plane; pub mod error; +mod ext; pub mod http; pub mod intern; pub mod jemalloc; diff --git a/proxy/src/logging.rs b/proxy/src/logging.rs index 74d2b9a1d01e..41f10f052ffa 100644 --- a/proxy/src/logging.rs +++ b/proxy/src/logging.rs @@ -18,8 +18,16 @@ pub async fn init() -> anyhow::Result { let env_filter = EnvFilter::builder() .with_default_directive(LevelFilter::INFO.into()) .from_env_lossy() - .add_directive("aws_config=info".parse().unwrap()) - .add_directive("azure_core::policies::transport=off".parse().unwrap()); + .add_directive( + "aws_config=info" + .parse() + .expect("this should be a valid filter directive"), + ) + .add_directive( + "azure_core::policies::transport=off" + .parse() + .expect("this should be a valid filter directive"), + ); let fmt_layer = tracing_subscriber::fmt::layer() .with_ansi(false) diff --git a/proxy/src/parse.rs b/proxy/src/parse.rs index 8c0f25106662..095d6278cc51 100644 --- a/proxy/src/parse.rs +++ b/proxy/src/parse.rs @@ -8,14 +8,6 @@ pub(crate) fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> { Some((cstr, other)) } -/// See . -pub(crate) fn split_at_const(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> { - (bytes.len() >= N).then(|| { - let (head, tail) = bytes.split_at(N); - (head.try_into().unwrap(), tail) - }) -} - #[cfg(test)] mod tests { use super::*; @@ -33,11 +25,4 @@ mod tests { assert_eq!(cstr.to_bytes(), b"foo"); assert_eq!(rest, b"bar"); } - - #[test] - fn test_split_at_const() { - assert!(split_at_const::<0>(b"").is_some()); - assert!(split_at_const::<1>(b"").is_none()); - assert!(matches!(split_at_const::<1>(b"ok"), Some((b"o", b"k")))); - } } diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 33a5eb5e1e03..0dc97b709724 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -396,6 +396,7 @@ impl NetworkEndianIpv6 { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use tokio::io::AsyncReadExt; diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index 4e4af8863484..3336a9556a5b 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -257,6 +257,7 @@ impl CopyBuffer { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use tokio::io::AsyncWriteExt; diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 5db92d748a5b..4e5ecda237d5 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -488,7 +488,7 @@ impl NeonOptions { pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> { static RE: OnceCell = OnceCell::new(); - let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap()); + let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").expect("regex should be correct")); let cap = re.captures(bytes)?; let (_, [k, v]) = cap.extract(); diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 3899ba4267ca..95c518fed9c2 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -1,4 +1,5 @@ //! A group of high-level tests for connection establishing logic and auth. +#![allow(clippy::unimplemented, clippy::unwrap_used)] mod mitm; diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 45f9630dde0f..bff800f0a2f0 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -83,7 +83,7 @@ impl From for utils::leaky_bucket::LeakyBucketConfig { } #[cfg(test)] -#[allow(clippy::float_cmp)] +#[allow(clippy::float_cmp, clippy::unwrap_used)] mod tests { use std::time::Duration; diff --git a/proxy/src/rate_limiter/limit_algorithm/aimd.rs b/proxy/src/rate_limiter/limit_algorithm/aimd.rs index 3000cc4c2af2..04e136b6d543 100644 --- a/proxy/src/rate_limiter/limit_algorithm/aimd.rs +++ b/proxy/src/rate_limiter/limit_algorithm/aimd.rs @@ -63,6 +63,7 @@ impl LimitAlgorithm for Aimd { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use std::time::Duration; diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index a048721e77d7..6f6a8c9d4781 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -12,6 +12,7 @@ use rand::{Rng, SeedableRng}; use tokio::time::{Duration, Instant}; use tracing::info; +use crate::ext::LockExt; use crate::intern::EndpointIdInt; pub struct GlobalRateLimiter { @@ -246,12 +247,13 @@ impl BucketRateLimiter { let n = self.map.shards().len(); // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide // (impossible, infact, unless we have 2048 threads) - let shard = self.rand.lock().unwrap().gen_range(0..n); + let shard = self.rand.lock_propagate_poison().gen_range(0..n); self.map.shards()[shard].write().clear(); } } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use std::hash::BuildHasherDefault; use std::time::Duration; diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs index 82139ea1d5e5..0f6e765b02cd 100644 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -69,7 +69,11 @@ impl ConnectionWithCredentialsProvider { pub fn new_with_static_credentials(params: T) -> Self { Self { - credentials: Credentials::Static(params.into_connection_info().unwrap()), + credentials: Credentials::Static( + params + .into_connection_info() + .expect("static configured redis credentials should be a valid format"), + ), con: None, refresh_token_task: None, mutex: tokio::sync::Mutex::new(()), diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs index 1373dfba3d9a..4922ece61531 100644 --- a/proxy/src/sasl/messages.rs +++ b/proxy/src/sasl/messages.rs @@ -2,7 +2,7 @@ use pq_proto::{BeAuthenticationSaslMessage, BeMessage}; -use crate::parse::{split_at_const, split_cstr}; +use crate::parse::split_cstr; /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage). #[derive(Debug)] @@ -19,7 +19,7 @@ impl<'a> FirstMessage<'a> { let (method_cstr, tail) = split_cstr(bytes)?; let method = method_cstr.to_str().ok()?; - let (len_bytes, bytes) = split_at_const(tail)?; + let (len_bytes, bytes) = tail.split_first_chunk()?; let len = u32::from_be_bytes(*len_bytes) as usize; if len != bytes.len() { return None; @@ -51,6 +51,7 @@ impl<'a> ServerMessage<&'a str> { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use super::*; diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs index 5ee3a513527d..0e54e7ded9a7 100644 --- a/proxy/src/scram/messages.rs +++ b/proxy/src/scram/messages.rs @@ -185,6 +185,7 @@ impl fmt::Debug for OwnedServerFirstMessage { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use super::*; diff --git a/proxy/src/scram/mod.rs b/proxy/src/scram/mod.rs index 718445f61d48..b49a9f32eec4 100644 --- a/proxy/src/scram/mod.rs +++ b/proxy/src/scram/mod.rs @@ -57,6 +57,7 @@ fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use super::threadpool::ThreadPool; use super::{Exchange, ServerSecret}; diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index 8c6a08d432d8..eb21b26ab40e 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -72,6 +72,7 @@ impl ServerSecret { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use super::*; diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index ebc6dd2a3cef..8f1684c75b0e 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -33,14 +33,11 @@ thread_local! { } impl ThreadPool { - pub fn new(n_workers: u8) -> Arc { + pub fn new(mut n_workers: u8) -> Arc { // rayon would be nice here, but yielding in rayon does not work well afaict. if n_workers == 0 { - return Arc::new(Self { - runtime: None, - metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)), - }); + n_workers = 1; } Arc::new_cyclic(|pool| { @@ -66,7 +63,7 @@ impl ThreadPool { }); }) .build() - .unwrap(); + .expect("password threadpool runtime should be configured correctly"); Self { runtime: Some(runtime), @@ -79,7 +76,7 @@ impl ThreadPool { JobHandle( self.runtime .as_ref() - .unwrap() + .expect("runtime is always set") .spawn(JobSpec { pbkdf2, endpoint }), ) } @@ -87,7 +84,10 @@ impl ThreadPool { impl Drop for ThreadPool { fn drop(&mut self) { - self.runtime.take().unwrap().shutdown_background(); + self.runtime + .take() + .expect("runtime is always set") + .shutdown_background(); } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 15d883bdb008..449d50b6e78b 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -268,7 +268,11 @@ impl PoolingBackend { if !self.local_pool.initialized(&conn_info) { // only install and grant usage one at a time. - let _permit = local_backend.initialize.acquire().await.unwrap(); + let _permit = local_backend + .initialize + .acquire() + .await + .expect("semaphore should never be closed"); // check again for race if !self.local_pool.initialized(&conn_info) { diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index cac5a173cb16..447103edce53 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -186,8 +186,8 @@ impl ClientDataRemote { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { - use std::mem; use std::sync::atomic::AtomicBool; use super::*; @@ -269,39 +269,33 @@ mod tests { assert_eq!(0, pool.get_global_connections_count()); } { - let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); - client.do_drop().unwrap()(); - mem::forget(client); // drop the client + let client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + drop(client); assert_eq!(1, pool.get_global_connections_count()); } { - let mut closed_client = Client::new( + let closed_client = Client::new( create_inner_with(MockClient::new(true)), conn_info.clone(), ep_pool.clone(), ); - closed_client.do_drop().unwrap()(); - mem::forget(closed_client); // drop the client - // The closed client shouldn't be added to the pool. + drop(closed_client); assert_eq!(1, pool.get_global_connections_count()); } let is_closed: Arc = Arc::new(false.into()); { - let mut client = Client::new( + let client = Client::new( create_inner_with(MockClient(is_closed.clone())), conn_info.clone(), ep_pool.clone(), ); - client.do_drop().unwrap()(); - mem::forget(client); // drop the client - + drop(client); // The client should be added to the pool. assert_eq!(2, pool.get_global_connections_count()); } { - let mut client = Client::new(create_inner(), conn_info, ep_pool); - client.do_drop().unwrap()(); - mem::forget(client); // drop the client + let client = Client::new(create_inner(), conn_info, ep_pool); + drop(client); // The client shouldn't be added to the pool. Because the ep-pool is full. assert_eq!(2, pool.get_global_connections_count()); @@ -319,15 +313,13 @@ mod tests { &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()), ); { - let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); - client.do_drop().unwrap()(); - mem::forget(client); // drop the client + let client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + drop(client); assert_eq!(3, pool.get_global_connections_count()); } { - let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); - client.do_drop().unwrap()(); - mem::forget(client); // drop the client + let client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + drop(client); // The client shouldn't be added to the pool. Because the global pool is full. assert_eq!(3, pool.get_global_connections_count()); diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index 2a46c8f9c5cf..44eac77e8f94 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -187,19 +187,22 @@ impl EndpointConnPool { pub(crate) fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInnerCommon) { let conn_id = client.get_conn_id(); - let pool_name = pool.read().get_name().to_string(); + let (max_conn, conn_count, pool_name) = { + let pool = pool.read(); + ( + pool.global_pool_size_max_conns, + pool.global_connections_count + .load(atomic::Ordering::Relaxed), + pool.get_name().to_string(), + ) + }; + if client.inner.is_closed() { info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name); return; } - let global_max_conn = pool.read().global_pool_size_max_conns; - if pool - .read() - .global_connections_count - .load(atomic::Ordering::Relaxed) - >= global_max_conn - { + if conn_count >= max_conn { info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name); return; } @@ -633,35 +636,29 @@ impl Client { } pub(crate) fn metrics(&self) -> Arc { - let aux = &self.inner.as_ref().unwrap().aux; + let aux = &self + .inner + .as_ref() + .expect("client inner should not be removed") + .aux; USAGE_METRICS.register(Ids { endpoint_id: aux.endpoint_id, branch_id: aux.branch_id, }) } +} - pub(crate) fn do_drop(&mut self) -> Option> { +impl Drop for Client { + fn drop(&mut self) { let conn_info = self.conn_info.clone(); let client = self .inner .take() .expect("client inner should not be removed"); if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() { - let current_span = self.span.clone(); + let _current_span = self.span.enter(); // return connection to the pool - return Some(move || { - let _span = current_span.enter(); - EndpointConnPool::put(&conn_pool, &conn_info, client); - }); - } - None - } -} - -impl Drop for Client { - fn drop(&mut self) { - if let Some(drop) = self.do_drop() { - tokio::task::spawn_blocking(drop); + EndpointConnPool::put(&conn_pool, &conn_info, client); } } } diff --git a/proxy/src/serverless/http_util.rs b/proxy/src/serverless/http_util.rs index c0208d4f68f1..d5c948777cae 100644 --- a/proxy/src/serverless/http_util.rs +++ b/proxy/src/serverless/http_util.rs @@ -81,11 +81,14 @@ impl HttpErrorBody { .header(http::header::CONTENT_TYPE, "application/json") // we do not have nested maps with non string keys so serialization shouldn't fail .body( - Full::new(Bytes::from(serde_json::to_string(self).unwrap())) - .map_err(|x| match x {}) - .boxed(), + Full::new(Bytes::from( + serde_json::to_string(self) + .expect("serialising HttpErrorBody should never fail"), + )) + .map_err(|x| match x {}) + .boxed(), ) - .unwrap() + .expect("content-type header should be valid") } } diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index 25b25c66d3fb..ab012bd020f1 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -204,7 +204,10 @@ fn pg_array_parse_inner( if c == '\\' { escaped = true; - (i, c) = pg_array_chr.next().unwrap(); + let Some(x) = pg_array_chr.next() else { + return Err(JsonConversionError::UnbalancedArray); + }; + (i, c) = x; } match c { @@ -253,6 +256,7 @@ fn pg_array_parse_inner( } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use serde_json::json; diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index b84cde9e252a..c51a2bc9babb 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -179,7 +179,6 @@ pub(crate) fn poll_client( info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection"); }); let pool = Arc::downgrade(&global_pool); - let pool_clone = pool.clone(); let db_user = conn_info.db_and_user(); let idle = global_pool.get_idle_timeout(); @@ -273,11 +272,7 @@ pub(crate) fn poll_client( }), }; - Client::new( - inner, - conn_info, - Arc::downgrade(&pool_clone.upgrade().unwrap().global_pool), - ) + Client::new(inner, conn_info, Arc::downgrade(&global_pool.global_pool)) } impl ClientInnerCommon { @@ -321,7 +316,8 @@ fn resign_jwt(sk: &SigningKey, payload: &[u8], jti: u64) -> Result(buffer.format(jti)).unwrap(); + let jti = serde_json::from_str::<&RawValue>(buffer.format(jti)) + .expect("itoa formatted integer should be guaranteed valid json"); // update the jti in-place let payload = @@ -368,6 +364,7 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use p256::ecdsa::SigningKey; use typed_json::json; diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 80b42f9e5534..c2623e0ecae3 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -46,6 +46,7 @@ use utils::http::error::ApiError; use crate::cancellation::CancellationHandlerMain; use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; +use crate::ext::TaskExt; use crate::metrics::Metrics; use crate::protocol2::{read_proxy_protocol, ChainRW, ConnectHeader, ConnectionInfo}; use crate::proxy::run_until_cancelled; @@ -84,7 +85,7 @@ pub async fn task_main( cancellation_token.cancelled().await; tokio::task::spawn_blocking(move || conn_pool.shutdown()) .await - .unwrap(); + .propagate_task_panic(); } }); @@ -104,7 +105,7 @@ pub async fn task_main( cancellation_token.cancelled().await; tokio::task::spawn_blocking(move || http_conn_pool.shutdown()) .await - .unwrap(); + .propagate_task_panic(); } }); diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 5e85f5ec4019..3e42787a0964 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1110,6 +1110,7 @@ impl Discard<'_> { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use super::*; diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index bdb83fe6be05..812fedaf0422 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -178,6 +178,7 @@ pub(crate) async fn serve_websocket( } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use std::pin::pin; diff --git a/proxy/src/url.rs b/proxy/src/url.rs index 270cd7c24da0..d73a84057ae3 100644 --- a/proxy/src/url.rs +++ b/proxy/src/url.rs @@ -50,6 +50,7 @@ impl std::fmt::Display for ApiUrl { } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use super::*; diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index 65e74466f2ec..487504d709ed 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -407,6 +407,7 @@ async fn upload_backup_events( } #[cfg(test)] +#[expect(clippy::unwrap_used)] mod tests { use std::fs; use std::io::BufReader;