From d9518a8bd2ca0860675ca42bd8ac8c8817ddbf53 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Wed, 28 Aug 2024 11:02:42 +0200 Subject: [PATCH 1/3] bb8: extract BrokenConnectionManager for reuse --- bb8/tests/test.rs | 53 +++++++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index d5db34e..9e78808 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -207,6 +207,35 @@ fn test_is_send_sync() { is_send_sync::>>(); } +// A connection manager that always returns `true` for `has_broken()` +struct BrokenConnectionManager { + _c: PhantomData, +} + +impl BrokenConnectionManager { + fn new() -> Self { + BrokenConnectionManager { _c: PhantomData } + } +} + +#[async_trait] +impl ManageConnection for BrokenConnectionManager { + type Connection = C; + type Error = Error; + + async fn connect(&self) -> Result { + Ok(C::default()) + } + + async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> { + Ok(()) + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + true + } +} + #[tokio::test] async fn test_drop_on_broken() { static DROPPED: AtomicBool = AtomicBool::new(false); @@ -220,27 +249,11 @@ async fn test_drop_on_broken() { } } - struct Handler; - - #[async_trait] - impl ManageConnection for Handler { - type Connection = Connection; - type Error = Error; - - async fn connect(&self) -> Result { - Ok(Default::default()) - } - - async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> { - Ok(()) - } - - fn has_broken(&self, _: &mut Self::Connection) -> bool { - true - } - } + let pool = Pool::builder() + .build(BrokenConnectionManager::::new()) + .await + .unwrap(); - let pool = Pool::builder().build(Handler).await.unwrap(); { let _ = pool.get().await.unwrap(); } From 408fe7b8834b25ba5afce201e0f4a24c7c6167e4 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Wed, 28 Aug 2024 11:07:29 +0200 Subject: [PATCH 2/3] bb8: replace trivial new() with derived Default impl --- bb8/tests/test.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index 9e78808..178e8d3 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -208,16 +208,11 @@ fn test_is_send_sync() { } // A connection manager that always returns `true` for `has_broken()` +#[derive(Default)] struct BrokenConnectionManager { _c: PhantomData, } -impl BrokenConnectionManager { - fn new() -> Self { - BrokenConnectionManager { _c: PhantomData } - } -} - #[async_trait] impl ManageConnection for BrokenConnectionManager { type Connection = C; @@ -250,7 +245,7 @@ async fn test_drop_on_broken() { } let pool = Pool::builder() - .build(BrokenConnectionManager::::new()) + .build(BrokenConnectionManager::::default()) .await .unwrap(); From 593a154f306f395ec70853eac406cabcda46c535 Mon Sep 17 00:00:00 2001 From: Taylor Neely Date: Fri, 26 Jul 2024 18:42:27 -0700 Subject: [PATCH 3/3] Add `Pool::add` Fixes #212 This adds `Pool::add`, which allows for externally created connections to be added and managed by the pool. If the pool is at maximum capacity when this method is called, it will return the input connection as part of the Err response. I considered allowing `Pool:add` to ignore `max_size` when adding to the pool, but felt it could lead to confusion if the pool is allowed to exceed its capacity in this specific case. This change required making PoolInternals::approvals visible within the crate to get the approval needed to add a new connection. The alternative would have required defining a new pub(crate) method for this specific use case, which feels worse. I'm open to suggestions on how to more cleanly integrate this change into the package. --- bb8/src/api.rs | 35 ++++++++++++++++++++++++++++++++ bb8/src/inner.rs | 13 +++++++++++- bb8/src/internals.rs | 11 ++++++++++ bb8/src/lib.rs | 2 +- bb8/tests/test.rs | 48 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 107 insertions(+), 2 deletions(-) diff --git a/bb8/src/api.rs b/bb8/src/api.rs index 1f2d598..857f592 100644 --- a/bb8/src/api.rs +++ b/bb8/src/api.rs @@ -75,6 +75,14 @@ impl Pool { pub fn state(&self) -> State { self.inner.state() } + + /// Adds a connection to the pool. + /// + /// If the connection is broken, or the pool is at capacity, the + /// connection is not added and instead returned to the caller in Err. + pub fn add(&self, conn: M::Connection) -> Result<(), AddError> { + self.inner.try_put(conn) + } } /// Information about the state of a `Pool`. @@ -526,6 +534,33 @@ where } } +/// Error type returned by `Pool::add(conn)` +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AddError { + /// The connection was broken before it could be added. + Broken(C), + /// Unable to add the connection to the pool due to insufficient capacity. + NoCapacity(C), +} + +impl fmt::Display for AddError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + AddError::Broken(_) => write!(f, "The connection was broken before it could be added"), + AddError::NoCapacity(_) => write!( + f, + "Unable to add the connection to the pool due to insufficient capacity" + ), + } + } +} + +impl error::Error for AddError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + None + } +} + /// A trait to receive errors generated by connection management that aren't /// tied to any particular caller. pub trait ErrorSink: fmt::Debug + Send + Sync + 'static { diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index 209d69f..be2ef52 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -9,7 +9,9 @@ use futures_util::TryFutureExt; use tokio::spawn; use tokio::time::{interval_at, sleep, timeout, Interval}; -use crate::api::{Builder, ConnectionState, ManageConnection, PooledConnection, RunError, State}; +use crate::api::{ + AddError, Builder, ConnectionState, ManageConnection, PooledConnection, RunError, State, +}; use crate::internals::{Approval, ApprovalIter, Conn, SharedPool, StatsGetKind, StatsKind}; pub(crate) struct PoolInner @@ -161,6 +163,15 @@ where } } + /// Adds an external connection to the pool if there is capacity for it. + pub(crate) fn try_put(&self, mut conn: M::Connection) -> Result<(), AddError> { + if self.inner.manager.has_broken(&mut conn) { + Err(AddError::Broken(conn)) + } else { + self.inner.try_put(conn).map_err(AddError::NoCapacity) + } + } + /// Returns information about the current state of the pool. pub(crate) fn state(&self) -> State { self.inner diff --git a/bb8/src/internals.rs b/bb8/src/internals.rs index 81fefab..155e21a 100644 --- a/bb8/src/internals.rs +++ b/bb8/src/internals.rs @@ -47,6 +47,17 @@ where (conn, approvals) } + pub(crate) fn try_put(self: &Arc, conn: M::Connection) -> Result<(), M::Connection> { + let mut locked = self.internals.lock(); + let mut approvals = locked.approvals(&self.statics, 1); + let Some(approval) = approvals.next() else { + return Err(conn); + }; + let conn = Conn::new(conn); + locked.put(conn, Some(approval), self.clone()); + Ok(()) + } + pub(crate) fn reap(&self) -> ApprovalIter { let mut locked = self.internals.lock(); let (iter, closed_idle_timeout, closed_max_lifetime) = locked.reap(&self.statics); diff --git a/bb8/src/lib.rs b/bb8/src/lib.rs index 5642cba..df3de74 100644 --- a/bb8/src/lib.rs +++ b/bb8/src/lib.rs @@ -35,7 +35,7 @@ mod api; pub use api::{ - Builder, CustomizeConnection, ErrorSink, ManageConnection, NopErrorSink, Pool, + AddError, Builder, CustomizeConnection, ErrorSink, ManageConnection, NopErrorSink, Pool, PooledConnection, QueueStrategy, RunError, State, Statistics, }; diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index 178e8d3..0e1225e 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -1020,3 +1020,51 @@ async fn test_statistics_connections_created() { assert_eq!(pool.state().statistics.connections_created, 1); } + +#[tokio::test] +async fn test_can_use_added_connections() { + let pool = Pool::builder() + .connection_timeout(Duration::from_millis(1)) + .build_unchecked(NthConnectionFailManager::::new(0)); + + // Assert pool can't replenish connections on its own + let res = pool.get().await; + assert_eq!(res.unwrap_err(), RunError::TimedOut); + + pool.add(FakeConnection).unwrap(); + let res = pool.get().await; + assert!(res.is_ok()); +} + +#[tokio::test] +async fn test_add_ok_until_max_size() { + let pool = Pool::builder() + .min_idle(1) + .max_size(3) + .build(OkManager::::new()) + .await + .unwrap(); + + for _ in 0..2 { + let conn = pool.dedicated_connection().await.unwrap(); + pool.add(conn).unwrap(); + } + + let conn = pool.dedicated_connection().await.unwrap(); + let res = pool.add(conn); + assert!(matches!(res, Err(AddError::NoCapacity(_)))); +} + +#[tokio::test] +async fn test_add_checks_broken_connections() { + let pool = Pool::builder() + .min_idle(1) + .max_size(3) + .build(BrokenConnectionManager::::default()) + .await + .unwrap(); + + let conn = pool.dedicated_connection().await.unwrap(); + let res = pool.add(conn); + assert!(matches!(res, Err(AddError::Broken(_)))); +}