Skip to content

Commit 50f2ff8

Browse files
Pool database connection and test validity
This should mitigate the periodic downtime we experience if the database connection dies, which previously required an external restart of the service. Now, if the connection is closed, the next request will automatically either use a different pooled connection or open a new one. The pooling implementation is partially extracted from rustc-perf, which has not encountered the database errors in production that triagebot has.
1 parent 686a5fe commit 50f2ff8

File tree

7 files changed

+98
-31
lines changed

7 files changed

+98
-31
lines changed

src/db.rs

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use anyhow::Context as _;
22
use native_tls::{Certificate, TlsConnector};
33
use postgres_native_tls::MakeTlsConnector;
4-
pub use tokio_postgres::Client as DbClient;
4+
use std::sync::{Arc, Mutex};
5+
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
6+
use tokio_postgres::Client as DbClient;
57

68
pub mod notifications;
79
pub mod rustc_commits;
@@ -19,7 +21,73 @@ lazy_static::lazy_static! {
1921
};
2022
}
2123

22-
pub async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
24+
pub struct ClientPool {
25+
connections: Arc<Mutex<Vec<tokio_postgres::Client>>>,
26+
permits: Arc<Semaphore>,
27+
}
28+
29+
pub struct PooledClient {
30+
client: Option<tokio_postgres::Client>,
31+
#[allow(unused)] // only used for drop impl
32+
permit: OwnedSemaphorePermit,
33+
pool: Arc<Mutex<Vec<tokio_postgres::Client>>>,
34+
}
35+
36+
impl Drop for PooledClient {
37+
fn drop(&mut self) {
38+
let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner());
39+
clients.push(self.client.take().unwrap());
40+
}
41+
}
42+
43+
impl std::ops::Deref for PooledClient {
44+
type Target = tokio_postgres::Client;
45+
46+
fn deref(&self) -> &Self::Target {
47+
self.client.as_ref().unwrap()
48+
}
49+
}
50+
51+
impl std::ops::DerefMut for PooledClient {
52+
fn deref_mut(&mut self) -> &mut Self::Target {
53+
self.client.as_mut().unwrap()
54+
}
55+
}
56+
57+
impl ClientPool {
58+
pub fn new() -> ClientPool {
59+
ClientPool {
60+
connections: Arc::new(Mutex::new(Vec::with_capacity(16))),
61+
permits: Arc::new(Semaphore::new(16)),
62+
}
63+
}
64+
65+
pub async fn get(&self) -> PooledClient {
66+
let permit = self.permits.clone().acquire_owned().await.unwrap();
67+
{
68+
let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner());
69+
// Pop connections until we hit a non-closed connection (or there are no
70+
// "possibly open" connections left).
71+
while let Some(c) = slots.pop() {
72+
if !c.is_closed() {
73+
return PooledClient {
74+
client: Some(c),
75+
permit,
76+
pool: self.connections.clone(),
77+
};
78+
}
79+
}
80+
}
81+
82+
PooledClient {
83+
client: Some(make_client().await.unwrap()),
84+
permit,
85+
pool: self.connections.clone(),
86+
}
87+
}
88+
}
89+
90+
async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
2391
let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
2492
if db_url.contains("rds.amazonaws.com") {
2593
let cert = &CERTIFICATE_PEM[..];

src/handlers.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use octocrab::Octocrab;
44
use parser::command::{Command, Input};
55
use std::fmt;
66
use std::sync::Arc;
7-
use tokio_postgres::Client as DbClient;
87

98
#[derive(Debug)]
109
pub enum HandlerError {
@@ -247,7 +246,7 @@ command_handlers! {
247246

248247
pub struct Context {
249248
pub github: GithubClient,
250-
pub db: DbClient,
249+
pub db: crate::db::ClientPool,
251250
pub username: String,
252251
pub octocrab: Octocrab,
253252
}

src/handlers/notification.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,21 +92,22 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
9292
}
9393
};
9494

95+
let client = ctx.db.get().await;
9596
for user in users {
9697
if !users_notified.insert(user.id.unwrap()) {
9798
// Skip users already associated with this event.
9899
continue;
99100
}
100101

101-
if let Err(err) = notifications::record_username(&ctx.db, user.id.unwrap(), user.login)
102+
if let Err(err) = notifications::record_username(&client, user.id.unwrap(), user.login)
102103
.await
103104
.context("failed to record username")
104105
{
105106
log::error!("record username: {:?}", err);
106107
}
107108

108109
if let Err(err) = notifications::record_ping(
109-
&ctx.db,
110+
&client,
110111
&notifications::Notification {
111112
user_id: user.id.unwrap(),
112113
origin_url: event.html_url().unwrap().to_owned(),

src/handlers/rustc_commits.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
7272
let mut sha = bors.merge_sha;
7373
let mut pr = Some(event.issue.number.try_into().unwrap());
7474

75+
let db = ctx.db.get().await;
7576
loop {
7677
// FIXME: ideally we would pull in all the commits here, but unfortunately
7778
// in rust-lang/rust's case there's bors-authored commits that aren't
@@ -101,7 +102,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
101102
};
102103

103104
let res = rustc_commits::record_commit(
104-
&ctx.db,
105+
&db,
105106
rustc_commits::Commit {
106107
sha: gc.sha,
107108
parent_sha: parent_sha.clone(),

src/main.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
3434
.unwrap());
3535
}
3636
if req.uri.path() == "/bors-commit-list" {
37-
let res = db::rustc_commits::get_commits_with_artifacts(&ctx.db).await;
37+
let res = db::rustc_commits::get_commits_with_artifacts(&*ctx.db.get().await).await;
3838
let res = match res {
3939
Ok(r) => r,
4040
Err(e) => {
@@ -57,7 +57,7 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
5757
return Ok(Response::builder()
5858
.status(StatusCode::OK)
5959
.body(Body::from(
60-
notification_listing::render(&ctx.db, &*name).await,
60+
notification_listing::render(&ctx.db.get().await, &*name).await,
6161
))
6262
.unwrap());
6363
}
@@ -187,10 +187,8 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
187187
async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
188188
log::info!("Listening on http://{}", addr);
189189

190-
let db_client = db::make_client()
191-
.await
192-
.context("open database connection")?;
193-
db::run_migrations(&db_client)
190+
let pool = db::ClientPool::new();
191+
db::run_migrations(&*pool.get().await)
194192
.await
195193
.context("database migrations")?;
196194

@@ -202,7 +200,7 @@ async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
202200
.expect("Failed to build octograb.");
203201
let ctx = Arc::new(Context {
204202
username: String::from("rustbot"),
205-
db: db_client,
203+
db: pool,
206204
github: gh,
207205
octocrab: oc,
208206
});

src/notification_listing.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use crate::db::notifications::get_notifications;
2-
use crate::db::DbClient;
32

4-
pub async fn render(db: &DbClient, user: &str) -> String {
3+
pub async fn render(db: &crate::db::PooledClient, user: &str) -> String {
54
let notifications = match get_notifications(db, user).await {
65
Ok(n) => n,
76
Err(e) => {

src/zulip.rs

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ fn handle_command<'a>(
119119
};
120120

121121
match next {
122-
Some("acknowledge") | Some("ack") => match acknowledge(gh_id, words).await {
122+
Some("acknowledge") | Some("ack") => match acknowledge(&ctx, gh_id, words).await {
123123
Ok(r) => r,
124124
Err(e) => serde_json::to_string(&Response {
125125
content: &format!(
@@ -139,7 +139,7 @@ fn handle_command<'a>(
139139
})
140140
.unwrap(),
141141
},
142-
Some("move") => match move_notification(gh_id, words).await {
142+
Some("move") => match move_notification(&ctx, gh_id, words).await {
143143
Ok(r) => r,
144144
Err(e) => serde_json::to_string(&Response {
145145
content: &format!(
@@ -149,7 +149,7 @@ fn handle_command<'a>(
149149
})
150150
.unwrap(),
151151
},
152-
Some("meta") => match add_meta_notification(gh_id, words).await {
152+
Some("meta") => match add_meta_notification(&ctx, gh_id, words).await {
153153
Ok(r) => r,
154154
Err(e) => serde_json::to_string(&Response {
155155
content: &format!(
@@ -513,7 +513,11 @@ impl<'a> UpdateMessageApiRequest<'a> {
513513
}
514514
}
515515

516-
async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyhow::Result<String> {
516+
async fn acknowledge(
517+
ctx: &Context,
518+
gh_id: i64,
519+
mut words: impl Iterator<Item = &str>,
520+
) -> anyhow::Result<String> {
517521
let filter = match words.next() {
518522
Some(filter) => {
519523
if words.next().is_some() {
@@ -533,7 +537,8 @@ async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyho
533537
} else {
534538
Identifier::Url(filter)
535539
};
536-
match delete_ping(&mut crate::db::make_client().await?, gh_id, ident).await {
540+
let mut db = ctx.db.get().await;
541+
match delete_ping(&mut *db, gh_id, ident).await {
537542
Ok(deleted) => {
538543
let resp = if deleted.is_empty() {
539544
format!(
@@ -588,7 +593,7 @@ async fn add_notification(
588593
Some(description)
589594
};
590595
match record_ping(
591-
&ctx.db,
596+
&*ctx.db.get().await,
592597
&notifications::Notification {
593598
user_id: gh_id,
594599
origin_url: url.to_owned(),
@@ -612,6 +617,7 @@ async fn add_notification(
612617
}
613618

614619
async fn add_meta_notification(
620+
ctx: &Context,
615621
gh_id: i64,
616622
mut words: impl Iterator<Item = &str>,
617623
) -> anyhow::Result<String> {
@@ -635,14 +641,8 @@ async fn add_meta_notification(
635641
assert_eq!(description.pop(), Some(' ')); // pop trailing space
636642
Some(description)
637643
};
638-
match add_metadata(
639-
&mut crate::db::make_client().await?,
640-
gh_id,
641-
idx,
642-
description.as_deref(),
643-
)
644-
.await
645-
{
644+
let mut db = ctx.db.get().await;
645+
match add_metadata(&mut db, gh_id, idx, description.as_deref()).await {
646646
Ok(()) => Ok(serde_json::to_string(&Response {
647647
content: "Added metadata!",
648648
})
@@ -655,6 +655,7 @@ async fn add_meta_notification(
655655
}
656656

657657
async fn move_notification(
658+
ctx: &Context,
658659
gh_id: i64,
659660
mut words: impl Iterator<Item = &str>,
660661
) -> anyhow::Result<String> {
@@ -676,7 +677,7 @@ async fn move_notification(
676677
.context("to index")?
677678
.checked_sub(1)
678679
.ok_or_else(|| anyhow::anyhow!("1-based indexes"))?;
679-
match move_indices(&mut crate::db::make_client().await?, gh_id, from, to).await {
680+
match move_indices(&mut *ctx.db.get().await, gh_id, from, to).await {
680681
Ok(()) => Ok(serde_json::to_string(&Response {
681682
// to 1-base indices
682683
content: &format!("Moved {} to {}.", from + 1, to + 1),

0 commit comments

Comments
 (0)