Skip to content

Commit

Permalink
Pool connections using bb8
Browse files Browse the repository at this point in the history
  • Loading branch information
SafariMonkey committed Jun 21, 2024
1 parent 3173039 commit 0a9dbcf
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 76 deletions.
18 changes: 16 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/replicate/client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ description = "A client api for state replication"
publish = false

[dependencies]
async-trait = "0.1.80"
base64.workspace = true
bb8 = "0.8.5"
bytes.workspace = true
eyre.workspace = true
futures.workspace = true
Expand Down
171 changes: 97 additions & 74 deletions crates/replicate/client/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,20 @@
use std::fmt::Debug;

use eyre::Result;
use eyre::{bail, ensure, Context, OptionExt};
use async_trait::async_trait;
use eyre::{bail, ensure, eyre, Context, OptionExt};
use eyre::{ContextCompat, Result};
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use replicate_common::{
messages::manager::{Clientbound as Cb, Serverbound as Sb},
InstanceId,
};
use tokio::sync::{mpsc, oneshot};
use url::Url;

use crate::connect_to_url;
use crate::Ascii;

/// The number of queued rpc calls allowed before we start erroring.
const RPC_CAPACITY: usize = 64;

type Framed = replicate_common::Framed<wtransport::stream::BiStream, Cb, Sb>;

/// Manages instances on the instance server. Under the hood, this is all done
Expand All @@ -29,10 +26,78 @@ type Framed = replicate_common::Framed<wtransport::stream::BiStream, Cb, Sb>;
/// user IDs.
#[derive(Debug)]
pub struct Manager {
_conn: wtransport::Connection,
pool: bb8::Pool<StreamPoolManager>,
url: Url,
task: tokio::task::JoinHandle<Result<()>>,
request_tx: mpsc::Sender<(Sb, oneshot::Sender<Cb>)>,
}

#[derive(Debug)]
struct StreamPoolManager {
conn: wtransport::Connection,
}

impl StreamPoolManager {
fn new(conn: wtransport::Connection) -> Self {
Self { conn }
}
}

// bb8 returns connections to the pool even if the drop is due to a panic.
// To avoid that, we drop the inner connection if the thread is panicking.
struct DropConnectionOnPanic<'a> {
pooled_connection: bb8::PooledConnection<'a, StreamPoolManager>,
}

impl<'a> Drop for DropConnectionOnPanic<'a> {
fn drop(&mut self) {
if std::thread::panicking() {
(*self.pooled_connection).take();
}
}
}

#[async_trait]
impl bb8::ManageConnection for StreamPoolManager {
/// The connection type this manager deals with.
type Connection = Option<Framed>;
/// The error type returned by `Connection`s.
type Error = eyre::Report;
/// Attempts to create a new connection.
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let bi = wtransport::stream::BiStream::join(
self.conn
.open_bi()
.await
.wrap_err("could not initiate bi stream")?
.await
.wrap_err("could not finish opening bi stream")?,
);

let framed = Framed::new(bi);
Ok(Some(framed))
}
/// Determines if the connection is still connected to the database.
async fn is_valid(&self, framed: &mut Self::Connection) -> Result<(), Self::Error> {
let framed = framed
.as_mut()
.wrap_err("connection was dropped due to panic")?;
framed
.send(Sb::HandshakeRequest)
.await
.wrap_err("failed to send handshake request")?;
let Some(msg) = framed.next().await else {
bail!("Server disconnected before completing handshake");
};
let msg = msg.wrap_err("error while receiving handshake response")?;
ensure!(
msg == Cb::HandshakeResponse,
"invalid message during handshake"
);
Ok(())
}
/// Synchronously determine if the connection is no longer usable, if possible.
fn has_broken(&self, framed: &mut Self::Connection) -> bool {
framed.is_none()
}
}

impl Manager {
Expand All @@ -49,41 +114,11 @@ impl Manager {
let conn = connect_to_url(&url, bearer_token)
.await
.wrap_err("failed to connect to server")?;
let bi = wtransport::stream::BiStream::join(
conn.open_bi()
.await
.wrap_err("could not initiate bi stream")?
.await
.wrap_err("could not finish opening bi stream")?,
);

let mut framed = Framed::new(bi);

// Do handshake before anything else
{
framed
.send(Sb::HandshakeRequest)
.await
.wrap_err("failed to send handshake request")?;
let Some(msg) = framed.next().await else {
bail!("Server disconnected before completing handshake");
};
let msg = msg.wrap_err("error while receiving handshake response")?;
ensure!(
msg == Cb::HandshakeResponse,
"invalid message during handshake"
);
}
let manager = StreamPoolManager::new(conn);
let pool = bb8::Pool::builder().build(manager).await.unwrap();

let (request_tx, request_rx) = mpsc::channel(RPC_CAPACITY);
let task = tokio::spawn(manager_task(framed, request_rx));

Ok(Self {
_conn: conn,
url,
task,
request_tx,
})
Ok(Self { pool, url })
}

pub async fn instance_create(&self) -> Result<InstanceId> {
Expand All @@ -102,37 +137,22 @@ impl Manager {
Ok(url)
}

/// Panics if the connection is already dead
async fn request(&self, request: Sb) -> Result<Cb> {
let (response_tx, response_rx) = oneshot::channel();
self.request_tx
.send((request, response_tx))
.await
.wrap_err("failed to send to manager task")?;
response_rx
.await
.wrap_err("failed to receive from manager task")
}

/// Destroys the manager and reaps any errors from its networking task
pub async fn join(self) -> Result<()> {
self.task
.await
.wrap_err("panic in manager task, file a bug report on github uwu")?
.wrap_err("error in task")
}

/// The url of this Manager.
pub fn url(&self) -> &Url {
&self.url
async fn get_framed(&self) -> Result<DropConnectionOnPanic<'_>> {
let pooled_connection = self.pool.get().await.map_err(|e| match e {
bb8::RunError::User(eyre) => {
eyre.wrap_err("get from connection pool failed")
}
bb8::RunError::TimedOut => eyre!("connection pool fetch timed out"),
})?;
Ok(DropConnectionOnPanic { pooled_connection })
}
}

async fn manager_task(
mut framed: Framed,
mut request_rx: mpsc::Receiver<(Sb, oneshot::Sender<Cb>)>,
) -> Result<()> {
while let Some((request, response_tx)) = request_rx.recv().await {
async fn request(&self, request: Sb) -> Result<Cb> {
let mut wrapper = self.get_framed().await?;
let framed = wrapper
.pooled_connection
.as_mut()
.expect("only emptied in Drop impl");
framed
.send(request)
.await
Expand All @@ -142,8 +162,11 @@ async fn manager_task(
.await
.ok_or_eyre("expected a response from the server")?
.wrap_err("error while receiving response")?;
let _ = response_tx.send(response);
Ok(response)
}

/// The url of this Manager.
pub fn url(&self) -> &Url {
&self.url
}
// We only return ok when the manager struct was dropped
Ok(())
}

0 comments on commit 0a9dbcf

Please sign in to comment.