Skip to content

Commit

Permalink
provide graceful shutdown interface
Browse files Browse the repository at this point in the history
This introduces a mechanism for politely asking Underway to shut down. To
do so, a new function, `graceful_shutdown`, is provided. Calling this
function will send a notification to a Postgres channel. Workers listen
on this channel and when a message is received will stop processing new
tasks. If they're already processing a task, then they wait until that
task is done or the task timeout has elapsed, whichever is first.

In order to cleanly stop the queue, this function should be used. If
stopping in-progress tasks is safe for your use case, then this can be
ignored and the queue can be stopped without any delay.

Closes #5
  • Loading branch information
maxcountryman committed Oct 12, 2024
1 parent 47e30c0 commit 4b3bf40
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 8 deletions.
22 changes: 22 additions & 0 deletions src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,28 @@ impl<T: Task> QueueBuilder<T, PoolSet> {
}
}

/// Name of the Postgres channel that `graceful_shutdown` notifies and that
/// workers listen on for shutdown requests.
pub(crate) const SHUTDOWN_CHANNEL: &str = "underway_shutdown";

/// Requests a graceful shutdown of all listening workers.
///
/// A `NOTIFY` with an empty payload is published on the `underway_shutdown`
/// channel by calling Postgres' `pg_notify` function through `executor`.
///
/// Workers subscribed to that channel stop picking up further tasks once the
/// notification arrives and wait for any in-progress task to finish or time
/// out.
///
/// Pairing this with [`tokio::signal`] is a convenient way to make sure queues
/// are stopped cleanly when your application is stopping.
///
/// # Errors
///
/// Returns an error if executing the `pg_notify` query fails.
pub async fn graceful_shutdown<'a, E>(executor: E) -> Result
where
    E: PgExecutor<'a>,
{
    let notification = sqlx::query!("select pg_notify($1, $2)", SHUTDOWN_CHANNEL, "");
    notification.execute(executor).await?;

    Ok(())
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;
Expand Down
118 changes: 110 additions & 8 deletions src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,24 +101,69 @@
//! # });
//! # }
//! ```
//!
//! # Stopping workers safely
//!
//! In order to ensure that workers are not interrupted while handling in-progress
//! tasks, the [`graceful_shutdown`](crate::queue::graceful_shutdown) function
//! is provided.
//!
//! This function allows you to politely ask all workers to stop processing new
//! tasks. At the same time, workers are also aware of any in-progress tasks
//! they're working on and will wait for these to be done or timeout.
//!
//! For cases where it's unimportant to wait for tasks to complete, this routine
//! can be ignored.
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use jiff::{Span, ToSpan};
use serde::{de::DeserializeOwned, Serialize};
use sqlx::{postgres::types::PgInterval, PgConnection};
use sqlx::{
postgres::{types::PgInterval, PgListener},
PgConnection,
};
use tokio::sync::Notify;
use tracing::instrument;

use crate::{
job::Job,
queue::{Error as QueueError, Queue},
queue::{Error as QueueError, Queue, SHUTDOWN_CHANNEL},
task::{DequeuedTask, Error as TaskError, Id as TaskId, RetryCount, RetryPolicy, Task},
};
pub(crate) type Result = std::result::Result<(), Error>;

/// A worker that's generic over the task it processes.
///
/// Besides the queue and task, a worker tracks shutdown and in-progress state
/// so that `run_every` can stop cleanly when a shutdown notification arrives.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct Worker<T: Task> {
// The queue that tasks are dequeued from.
queue: Queue<T>,
// The task implementation; its `timeout()` bounds how long shutdown waits
// for an in-progress task.
task: T,

// Indicates the underlying queue has received a shutdown signal.
// Held in an `Arc` so that clones of this worker share the same flag.
queue_shutdown: Arc<AtomicBool>,

// Indicates that this worker is processing a task.
// Set once a task has been dequeued and cleared after its transaction
// commits.
processing: AtomicBool,

// Notifies when a task is done processing.
// Waiters are woken right after `processing` is cleared.
processing_done: Notify,
}

impl<T: Task + Clone> Clone for Worker<T> {
fn clone(&self) -> Self {
Self {
queue: self.queue.clone(),
task: self.task.clone(),
queue_shutdown: self.queue_shutdown.clone(),

// Clones don't share processing state.
processing: false.into(),
processing_done: Notify::new(),
}
}
}

/// Worker errors.
Expand Down Expand Up @@ -154,6 +199,9 @@ where
Self {
queue: job.queue.clone(),
task: job,
queue_shutdown: Arc::new(false.into()),
processing: false.into(),
processing_done: Notify::new(),
}
}
}
Expand All @@ -166,14 +214,23 @@ where
Self {
queue: job.queue.clone(),
task: job.clone(),
queue_shutdown: Arc::new(false.into()),
processing: false.into(),
processing_done: Notify::new(),
}
}
}

impl<T: Task> Worker<T> {
/// Creates a new worker with the given queue and task.
///
/// The worker starts with its shutdown flag cleared, not processing any task,
/// and with a fresh done notifier.
pub const fn new(queue: Queue<T>, task: T) -> Self {
Self { queue, task }
pub fn new(queue: Queue<T>, task: T) -> Self {
Self {
queue,
task,
// No shutdown has been observed yet.
queue_shutdown: Arc::new(false.into()),
// No task is in progress yet.
processing: false.into(),
processing_done: Notify::new(),
}
}

/// Runs the worker, processing tasks as they become available.
Expand All @@ -188,11 +245,51 @@ impl<T: Task> Worker<T> {
/// polls.
pub async fn run_every(&self, span: Span) -> Result {
// Poll interval derived from `span`; a failed conversion is returned via `?`.
let mut interval = tokio::time::interval(span.try_into()?);
interval.tick().await;

// Set up a listener for shutdown notifications
// It connects through the queue's pool and subscribes to `SHUTDOWN_CHANNEL`.
let mut listener = PgListener::connect_with(&self.queue.pool).await?;
listener.listen(SHUTDOWN_CHANNEL).await?;

loop {
self.process_next_task().await?;
interval.tick().await;
// Race a shutdown notification against the next poll tick.
tokio::select! {
// NOTE(review): the result of `recv` is discarded, so an `Err` from the
// listener would also take this arm and stop the worker — confirm that
// is intended.
_ = listener.recv() => {
tracing::info!("Queue shutdown notification received, initiating shutdown");
// Record the shutdown on the shared flag, which the tick arm checks
// before processing.
self.queue_shutdown.store(true, Ordering::SeqCst);

// Only wait if a task is currently marked as in-progress.
if self.processing.load(Ordering::SeqCst) {
let task_timeout = self.task.timeout();
tracing::info!(
task.timeout = ?task_timeout,
"Worker is in-progress with a task, waiting for done notification or task timeout"
);

// Wait for either task processing to be done or the task timeout,
// whichever is first.
tokio::select! {
_ = self.processing_done.notified() => {
tracing::debug!("Received done notification");
},
_ = tokio::time::sleep(task_timeout.try_into()?) => {
tracing::debug!("Slept until task timeout");
},
}
}

// Leave the polling loop; no further tasks are picked up.
break;
},

_ = interval.tick() => {
// The shared flag may already be set, so check it before taking on
// new work.
if self.queue_shutdown.load(Ordering::SeqCst) {
tracing::info!("Queue is shutdown so worker will not process more tasks");
break;
}

// NOTE(review): the task is awaited inside this arm, so the shutdown
// arm is not polled again until it returns — confirm this is the
// intended way of letting an in-progress task finish.
self.process_next_task().await?;
},
}
}

Ok(())
}

/// Processes the next available task in the queue.
Expand All @@ -213,6 +310,8 @@ impl<T: Task> Worker<T> {
let mut tx = self.queue.pool.begin().await?;

if let Some(task_row) = self.queue.dequeue(&mut tx).await? {
self.processing.store(true, Ordering::SeqCst);

let task_id = task_row.id;
tracing::Span::current().record("task.id", task_id.as_hyphenated().to_string());

Expand Down Expand Up @@ -248,6 +347,9 @@ impl<T: Task> Worker<T> {
}

tx.commit().await?;

self.processing.store(false, Ordering::SeqCst);
self.processing_done.notify_waiters();
}

Ok(())
Expand Down

0 comments on commit 4b3bf40

Please sign in to comment.