Skip to content

Commit b5c6c17

Browse files
Merge pull request #1481 from Mark-Simulacrum/pool
Pool database connection and test validity
2 parents 686a5fe + 50f2ff8 commit b5c6c17

File tree

7 files changed

+98
-31
lines changed

7 files changed

+98
-31
lines changed

src/db.rs

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use anyhow::Context as _;
22
use native_tls::{Certificate, TlsConnector};
33
use postgres_native_tls::MakeTlsConnector;
4-
pub use tokio_postgres::Client as DbClient;
4+
use std::sync::{Arc, Mutex};
5+
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
6+
use tokio_postgres::Client as DbClient;
57

68
pub mod notifications;
79
pub mod rustc_commits;
@@ -19,7 +21,73 @@ lazy_static::lazy_static! {
1921
};
2022
}
2123

22-
pub async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
24+
pub struct ClientPool {
25+
connections: Arc<Mutex<Vec<tokio_postgres::Client>>>,
26+
permits: Arc<Semaphore>,
27+
}
28+
29+
pub struct PooledClient {
30+
client: Option<tokio_postgres::Client>,
31+
#[allow(unused)] // only used for drop impl
32+
permit: OwnedSemaphorePermit,
33+
pool: Arc<Mutex<Vec<tokio_postgres::Client>>>,
34+
}
35+
36+
impl Drop for PooledClient {
37+
fn drop(&mut self) {
38+
let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner());
39+
clients.push(self.client.take().unwrap());
40+
}
41+
}
42+
43+
impl std::ops::Deref for PooledClient {
44+
type Target = tokio_postgres::Client;
45+
46+
fn deref(&self) -> &Self::Target {
47+
self.client.as_ref().unwrap()
48+
}
49+
}
50+
51+
impl std::ops::DerefMut for PooledClient {
52+
fn deref_mut(&mut self) -> &mut Self::Target {
53+
self.client.as_mut().unwrap()
54+
}
55+
}
56+
57+
impl ClientPool {
58+
pub fn new() -> ClientPool {
59+
ClientPool {
60+
connections: Arc::new(Mutex::new(Vec::with_capacity(16))),
61+
permits: Arc::new(Semaphore::new(16)),
62+
}
63+
}
64+
65+
pub async fn get(&self) -> PooledClient {
66+
let permit = self.permits.clone().acquire_owned().await.unwrap();
67+
{
68+
let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner());
69+
// Pop connections until we hit a non-closed connection (or there are no
70+
// "possibly open" connections left).
71+
while let Some(c) = slots.pop() {
72+
if !c.is_closed() {
73+
return PooledClient {
74+
client: Some(c),
75+
permit,
76+
pool: self.connections.clone(),
77+
};
78+
}
79+
}
80+
}
81+
82+
PooledClient {
83+
client: Some(make_client().await.unwrap()),
84+
permit,
85+
pool: self.connections.clone(),
86+
}
87+
}
88+
}
89+
90+
async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
2391
let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
2492
if db_url.contains("rds.amazonaws.com") {
2593
let cert = &CERTIFICATE_PEM[..];

src/handlers.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use octocrab::Octocrab;
44
use parser::command::{Command, Input};
55
use std::fmt;
66
use std::sync::Arc;
7-
use tokio_postgres::Client as DbClient;
87

98
#[derive(Debug)]
109
pub enum HandlerError {
@@ -247,7 +246,7 @@ command_handlers! {
247246

248247
pub struct Context {
249248
pub github: GithubClient,
250-
pub db: DbClient,
249+
pub db: crate::db::ClientPool,
251250
pub username: String,
252251
pub octocrab: Octocrab,
253252
}

src/handlers/notification.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,21 +92,22 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
9292
}
9393
};
9494

95+
let client = ctx.db.get().await;
9596
for user in users {
9697
if !users_notified.insert(user.id.unwrap()) {
9798
// Skip users already associated with this event.
9899
continue;
99100
}
100101

101-
if let Err(err) = notifications::record_username(&ctx.db, user.id.unwrap(), user.login)
102+
if let Err(err) = notifications::record_username(&client, user.id.unwrap(), user.login)
102103
.await
103104
.context("failed to record username")
104105
{
105106
log::error!("record username: {:?}", err);
106107
}
107108

108109
if let Err(err) = notifications::record_ping(
109-
&ctx.db,
110+
&client,
110111
&notifications::Notification {
111112
user_id: user.id.unwrap(),
112113
origin_url: event.html_url().unwrap().to_owned(),

src/handlers/rustc_commits.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
7272
let mut sha = bors.merge_sha;
7373
let mut pr = Some(event.issue.number.try_into().unwrap());
7474

75+
let db = ctx.db.get().await;
7576
loop {
7677
// FIXME: ideally we would pull in all the commits here, but unfortunately
7778
// in rust-lang/rust's case there's bors-authored commits that aren't
@@ -101,7 +102,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
101102
};
102103

103104
let res = rustc_commits::record_commit(
104-
&ctx.db,
105+
&db,
105106
rustc_commits::Commit {
106107
sha: gc.sha,
107108
parent_sha: parent_sha.clone(),

src/main.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
3434
.unwrap());
3535
}
3636
if req.uri.path() == "/bors-commit-list" {
37-
let res = db::rustc_commits::get_commits_with_artifacts(&ctx.db).await;
37+
let res = db::rustc_commits::get_commits_with_artifacts(&*ctx.db.get().await).await;
3838
let res = match res {
3939
Ok(r) => r,
4040
Err(e) => {
@@ -57,7 +57,7 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
5757
return Ok(Response::builder()
5858
.status(StatusCode::OK)
5959
.body(Body::from(
60-
notification_listing::render(&ctx.db, &*name).await,
60+
notification_listing::render(&ctx.db.get().await, &*name).await,
6161
))
6262
.unwrap());
6363
}
@@ -187,10 +187,8 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
187187
async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
188188
log::info!("Listening on http://{}", addr);
189189

190-
let db_client = db::make_client()
191-
.await
192-
.context("open database connection")?;
193-
db::run_migrations(&db_client)
190+
let pool = db::ClientPool::new();
191+
db::run_migrations(&*pool.get().await)
194192
.await
195193
.context("database migrations")?;
196194

@@ -202,7 +200,7 @@ async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
202200
.expect("Failed to build octograb.");
203201
let ctx = Arc::new(Context {
204202
username: String::from("rustbot"),
205-
db: db_client,
203+
db: pool,
206204
github: gh,
207205
octocrab: oc,
208206
});

src/notification_listing.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use crate::db::notifications::get_notifications;
2-
use crate::db::DbClient;
32

4-
pub async fn render(db: &DbClient, user: &str) -> String {
3+
pub async fn render(db: &crate::db::PooledClient, user: &str) -> String {
54
let notifications = match get_notifications(db, user).await {
65
Ok(n) => n,
76
Err(e) => {

src/zulip.rs

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ fn handle_command<'a>(
119119
};
120120

121121
match next {
122-
Some("acknowledge") | Some("ack") => match acknowledge(gh_id, words).await {
122+
Some("acknowledge") | Some("ack") => match acknowledge(&ctx, gh_id, words).await {
123123
Ok(r) => r,
124124
Err(e) => serde_json::to_string(&Response {
125125
content: &format!(
@@ -139,7 +139,7 @@ fn handle_command<'a>(
139139
})
140140
.unwrap(),
141141
},
142-
Some("move") => match move_notification(gh_id, words).await {
142+
Some("move") => match move_notification(&ctx, gh_id, words).await {
143143
Ok(r) => r,
144144
Err(e) => serde_json::to_string(&Response {
145145
content: &format!(
@@ -149,7 +149,7 @@ fn handle_command<'a>(
149149
})
150150
.unwrap(),
151151
},
152-
Some("meta") => match add_meta_notification(gh_id, words).await {
152+
Some("meta") => match add_meta_notification(&ctx, gh_id, words).await {
153153
Ok(r) => r,
154154
Err(e) => serde_json::to_string(&Response {
155155
content: &format!(
@@ -513,7 +513,11 @@ impl<'a> UpdateMessageApiRequest<'a> {
513513
}
514514
}
515515

516-
async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyhow::Result<String> {
516+
async fn acknowledge(
517+
ctx: &Context,
518+
gh_id: i64,
519+
mut words: impl Iterator<Item = &str>,
520+
) -> anyhow::Result<String> {
517521
let filter = match words.next() {
518522
Some(filter) => {
519523
if words.next().is_some() {
@@ -533,7 +537,8 @@ async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyho
533537
} else {
534538
Identifier::Url(filter)
535539
};
536-
match delete_ping(&mut crate::db::make_client().await?, gh_id, ident).await {
540+
let mut db = ctx.db.get().await;
541+
match delete_ping(&mut *db, gh_id, ident).await {
537542
Ok(deleted) => {
538543
let resp = if deleted.is_empty() {
539544
format!(
@@ -588,7 +593,7 @@ async fn add_notification(
588593
Some(description)
589594
};
590595
match record_ping(
591-
&ctx.db,
596+
&*ctx.db.get().await,
592597
&notifications::Notification {
593598
user_id: gh_id,
594599
origin_url: url.to_owned(),
@@ -612,6 +617,7 @@ async fn add_notification(
612617
}
613618

614619
async fn add_meta_notification(
620+
ctx: &Context,
615621
gh_id: i64,
616622
mut words: impl Iterator<Item = &str>,
617623
) -> anyhow::Result<String> {
@@ -635,14 +641,8 @@ async fn add_meta_notification(
635641
assert_eq!(description.pop(), Some(' ')); // pop trailing space
636642
Some(description)
637643
};
638-
match add_metadata(
639-
&mut crate::db::make_client().await?,
640-
gh_id,
641-
idx,
642-
description.as_deref(),
643-
)
644-
.await
645-
{
644+
let mut db = ctx.db.get().await;
645+
match add_metadata(&mut db, gh_id, idx, description.as_deref()).await {
646646
Ok(()) => Ok(serde_json::to_string(&Response {
647647
content: "Added metadata!",
648648
})
@@ -655,6 +655,7 @@ async fn add_meta_notification(
655655
}
656656

657657
async fn move_notification(
658+
ctx: &Context,
658659
gh_id: i64,
659660
mut words: impl Iterator<Item = &str>,
660661
) -> anyhow::Result<String> {
@@ -676,7 +677,7 @@ async fn move_notification(
676677
.context("to index")?
677678
.checked_sub(1)
678679
.ok_or_else(|| anyhow::anyhow!("1-based indexes"))?;
679-
match move_indices(&mut crate::db::make_client().await?, gh_id, from, to).await {
680+
match move_indices(&mut *ctx.db.get().await, gh_id, from, to).await {
680681
Ok(()) => Ok(serde_json::to_string(&Response {
681682
// to 1-base indices
682683
content: &format!("Moved {} to {}.", from + 1, to + 1),

0 commit comments

Comments
 (0)