diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index 37acade..fa56613 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -91,6 +91,7 @@ where let mut wait_time_start = None; let future = async { + let _getting = self.inner.start_get(); loop { let (conn, approvals) = self.inner.pop(); self.spawn_replenishing_approvals(approvals); diff --git a/bb8/src/internals.rs b/bb8/src/internals.rs index 5cac9f6..249af7d 100644 --- a/bb8/src/internals.rs +++ b/bb8/src/internals.rs @@ -38,13 +38,16 @@ where pub(crate) fn pop(&self) -> (Option>, ApprovalIter) { let mut locked = self.internals.lock(); - let conn = locked.conns.pop_front().map(|idle| idle.conn); - let approvals = match &conn { - Some(_) => locked.wanted(&self.statics), - None => locked.approvals(&self.statics, 1), + if let Some(IdleConn { conn, .. }) = locked.conns.pop_front() { + return (Some(conn), locked.wanted(&self.statics)); + } + + let approvals = match locked.in_flight > locked.pending_conns { + true => 1, + false => 0, }; - (conn, approvals) + (None, locked.approvals(&self.statics, approvals)) } pub(crate) fn try_put(self: &Arc, conn: M::Connection) -> Result<(), M::Connection> { @@ -67,6 +70,10 @@ where iter } + pub(crate) fn start_get(self: &Arc) -> Getting { + Getting::new(self.clone()) + } + pub(crate) fn forward_error(&self, err: M::Error) { self.statics.error_sink.sink(err); } @@ -81,6 +88,7 @@ where conns: VecDeque>, num_conns: u32, pending_conns: u32, + in_flight: u32, } impl PoolInternals @@ -202,6 +210,7 @@ where conns: VecDeque::new(), num_conns: 0, pending_conns: 0, + in_flight: 0, } } } @@ -236,6 +245,27 @@ pub(crate) struct Approval { _priv: (), } +pub(crate) struct Getting { + inner: Arc>, +} + +impl Getting { + fn new(inner: Arc>) -> Self { + { + let mut locked = inner.internals.lock(); + locked.in_flight += 1; + } + Getting { inner } + } +} + +impl Drop for Getting { + fn drop(&mut self) { + let mut locked = self.inner.internals.lock(); + locked.in_flight -= 1; + } +} + #[derive(Debug)] pub(crate) struct Conn where diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index 724a376..c46fcc3 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -1103,3 +1103,37 @@ async fn test_add_checks_broken_connections() { let res = pool.add(conn); assert!(matches!(res, Err(AddError::Broken(_)))); } + +#[tokio::test] +async fn test_reuse_on_drop() { + let pool = Pool::builder() + .min_idle(0) + .max_size(100) + .queue_strategy(QueueStrategy::Lifo) + .build(OkManager::::new()) + .await + .unwrap(); + + // The first get should + // 1) see nothing in the pool, + // 2) spawn a single replenishing approval, + // 3) get notified of the new connection and grab it from the pool + let conn_0 = pool.get().await.expect("should connect"); + // Dropping the connection queues up a notify + drop(conn_0); + // The second get should + // 1) see the first connection in the pool and grab it + let _conn_1: PooledConnection> = + pool.get().await.expect("should connect"); + // The third get will + // 1) see nothing in the pool, + // 2) spawn a single replenishing approval, + // 3) get notified of the new connection, + // 4) see nothing in the pool, + // 5) _not_ spawn a single replenishing approval, + // 6) get notified of the new connection and grab it from the pool + let _conn_2: PooledConnection> = + pool.get().await.expect("should connect"); + + assert_eq!(pool.state().connections, 2); +}