Skip to content

Commit

Permalink
Add example for wrapping PostgresConnectionManager to add state
Browse files Browse the repository at this point in the history
With this example we show how easy it is to wrap the postgres connection
pool to prepare several queries when a connection is created. Then, once
a connection is checked out, this examples shows how easy it is to use
the custom state to pull a prepared statement.

Inspired by #110 which requests this kind of behavior in bb8 directly,
but for now it's easy to extend any connection manager with your custom
needs.
  • Loading branch information
film42 authored and djc committed Oct 28, 2021
1 parent 4e9862c commit 687d394
Showing 1 changed file with 133 additions and 0 deletions.
133 changes: 133 additions & 0 deletions postgres/examples/custom_state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use std::collections::BTreeMap;
use std::ops::Deref;
use std::str::FromStr;

use async_trait::async_trait;
use bb8::{CustomizeConnection, Pool};
use bb8_postgres::PostgresConnectionManager;
use tokio_postgres::config::Config;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::{Client, Error, Socket, Statement};

// Select some static data from a Postgres DB
//
// The simplest way to start the db is using Docker:
// docker run --name gotham-middleware-postgres -e POSTGRES_PASSWORD=mysecretpassword -p 5432:5432 -d postgres
#[tokio::main]
async fn main() {
let config =
tokio_postgres::config::Config::from_str("postgresql://postgres:docker@localhost:5432")
.unwrap();
let pg_mgr = CustomPostgresConnectionManager::new(config, tokio_postgres::NoTls);

let pool = Pool::builder()
.connection_customizer(Box::new(Customizer))
.build(pg_mgr)
.await
.expect("build error");

let connection = pool.get().await.expect("pool error");

let row = connection
.query_one(
connection
.custom_state
.get(&QueryName::Addition)
.expect("statement not predefined"),
&[],
)
.await
.expect("query failed");

println!("result: {}", row.get::<usize, i32>(0));
}

#[derive(Debug)]
struct Customizer;

#[async_trait]
impl<'a> CustomizeConnection<CustomPostgresConnection, Error> for Customizer {
async fn on_acquire(&self, conn: &mut CustomPostgresConnection) -> Result<(), Error> {
conn.custom_state
.insert(QueryName::BasicSelect, conn.prepare("SELECT 1").await?);

conn.custom_state
.insert(QueryName::Addition, conn.prepare("SELECT 1 + 1 + 1").await?);

Ok(())
}
}

struct CustomPostgresConnection {
inner: Client,
custom_state: BTreeMap<QueryName, Statement>,
}

impl CustomPostgresConnection {
fn new(inner: Client) -> Self {
Self {
inner,
custom_state: Default::default(),
}
}
}

impl Deref for CustomPostgresConnection {
type Target = Client;

fn deref(&self) -> &Self::Target {
&self.inner
}
}

struct CustomPostgresConnectionManager<Tls>
where
Tls: MakeTlsConnect<Socket>,
{
inner: PostgresConnectionManager<Tls>,
}

impl<Tls> CustomPostgresConnectionManager<Tls>
where
Tls: MakeTlsConnect<Socket>,
{
pub fn new(config: Config, tls: Tls) -> Self {
Self {
inner: PostgresConnectionManager::new(config, tls),
}
}
}

#[async_trait]
impl<Tls> bb8::ManageConnection for CustomPostgresConnectionManager<Tls>
where
Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
<Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
<Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
<<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
type Connection = CustomPostgresConnection;
type Error = Error;

async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let conn = self.inner.connect().await?;
Ok(CustomPostgresConnection::new(conn))
}

async fn is_valid(
&self,
conn: &mut bb8::PooledConnection<'_, Self>,
) -> Result<(), Self::Error> {
conn.simple_query("").await.map(|_| ())
}

fn has_broken(&self, conn: &mut Self::Connection) -> bool {
self.inner.has_broken(&mut conn.inner)
}
}

#[derive(Debug, Ord, PartialOrd, Eq, PartialEq)]
enum QueryName {
BasicSelect,
Addition,
}

0 comments on commit 687d394

Please sign in to comment.