diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index 37acade..7140035 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -91,8 +91,9 @@ where let mut wait_time_start = None; let future = async { + let getting = self.inner.start_get(); loop { - let (conn, approvals) = self.inner.pop(); + let (conn, approvals) = getting.get(); self.spawn_replenishing_approvals(approvals); // Cancellation safety: make sure to wrap the connection in a `PooledConnection` diff --git a/bb8/src/internals.rs b/bb8/src/internals.rs index 5cac9f6..ed8aa41 100644 --- a/bb8/src/internals.rs +++ b/bb8/src/internals.rs @@ -36,17 +36,6 @@ where } } - pub(crate) fn pop(&self) -> (Option>, ApprovalIter) { - let mut locked = self.internals.lock(); - let conn = locked.conns.pop_front().map(|idle| idle.conn); - let approvals = match &conn { - Some(_) => locked.wanted(&self.statics), - None => locked.approvals(&self.statics, 1), - }; - - (conn, approvals) - } - pub(crate) fn try_put(self: &Arc, conn: M::Connection) -> Result<(), M::Connection> { let mut locked = self.internals.lock(); let mut approvals = locked.approvals(&self.statics, 1); @@ -67,6 +56,10 @@ where iter } + pub(crate) fn start_get(self: &Arc) -> Getting { + Getting::new(self.clone()) + } + pub(crate) fn forward_error(&self, err: M::Error) { self.statics.error_sink.sink(err); } @@ -81,6 +74,7 @@ where conns: VecDeque>, num_conns: u32, pending_conns: u32, + in_flight: u32, } impl PoolInternals @@ -202,6 +196,7 @@ where conns: VecDeque::new(), num_conns: 0, pending_conns: 0, + in_flight: 0, } } } @@ -236,6 +231,43 @@ pub(crate) struct Approval { _priv: (), } +pub(crate) struct Getting { + inner: Arc>, +} + +impl Getting { + pub(crate) fn get(&self) -> (Option>, ApprovalIter) { + let mut locked = self.inner.internals.lock(); + if let Some(IdleConn { conn, .. }) = locked.conns.pop_front() { + return (Some(conn), locked.wanted(&self.inner.statics)); + } + + let approvals = match locked.in_flight > locked.pending_conns { + true => 1, + false => 0, + }; + + (None, locked.approvals(&self.inner.statics, approvals)) + } +} + +impl Getting { + fn new(inner: Arc>) -> Self { + { + let mut locked = inner.internals.lock(); + locked.in_flight += 1; + } + Getting { inner } + } +} + +impl Drop for Getting { + fn drop(&mut self) { + let mut locked = self.inner.internals.lock(); + locked.in_flight -= 1; + } +} + #[derive(Debug)] pub(crate) struct Conn where diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index 724a376..8680309 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -1103,3 +1103,37 @@ async fn test_add_checks_broken_connections() { let res = pool.add(conn); assert!(matches!(res, Err(AddError::Broken(_)))); } + +#[tokio::test] +async fn test_reuse_on_drop() { + let pool = Pool::builder() + .min_idle(0) + .max_size(100) + .queue_strategy(QueueStrategy::Lifo) + .build(OkManager::::new()) + .await + .unwrap(); + + // The first get should + // 1) see nothing in the pool, + // 2) spawn a single replenishing approval, + // 3) get notified of the new connection and grab it from the pool + let conn_0 = pool.get().await.expect("should connect"); + // Dropping the connection queues up a notify + drop(conn_0); + + // The second get should + // 1) see the first connection in the pool and grab it + let _conn_1 = pool.get().await.expect("should connect"); + + // The third get will + // 1) see nothing in the pool, + // 2) spawn a single replenishing approval, + // 3) get notified of the new connection, + // 4) see nothing in the pool, + // 5) _not_ spawn a single replenishing approval, + // 6) get notified of the new connection and grab it from the pool + let _conn_2 = pool.get().await.expect("should connect"); + + assert_eq!(pool.state().connections, 2); +}