diff --git a/.sqlx/query-45b1b27f9669db53892c4cfad7d09c0e325cfacbf1c37589300fb48e8d9eac49.json b/.sqlx/query-45b1b27f9669db53892c4cfad7d09c0e325cfacbf1c37589300fb48e8d9eac49.json
new file mode 100644
index 0000000..3f3fae0
--- /dev/null
+++ b/.sqlx/query-45b1b27f9669db53892c4cfad7d09c0e325cfacbf1c37589300fb48e8d9eac49.json
@@ -0,0 +1,35 @@
+{
+  "db_name": "PostgreSQL",
+  "query": "\n            select count(*)\n            from underway.task\n            where state = $1\n            ",
+  "describe": {
+    "columns": [
+      {
+        "ordinal": 0,
+        "name": "count",
+        "type_info": "Int8"
+      }
+    ],
+    "parameters": {
+      "Left": [
+        {
+          "Custom": {
+            "name": "underway.task_state",
+            "kind": {
+              "Enum": [
+                "pending",
+                "in_progress",
+                "succeeded",
+                "cancelled",
+                "failed"
+              ]
+            }
+          }
+        }
+      ]
+    },
+    "nullable": [
+      null
+    ]
+  },
+  "hash": "45b1b27f9669db53892c4cfad7d09c0e325cfacbf1c37589300fb48e8d9eac49"
+}
diff --git a/.sqlx/query-54d124a54b2bb28f85b3ee9882f1e103d8e690ea0cb5189411834b9d8b246fc4.json b/.sqlx/query-54d124a54b2bb28f85b3ee9882f1e103d8e690ea0cb5189411834b9d8b246fc4.json
new file mode 100644
index 0000000..cd6c018
--- /dev/null
+++ b/.sqlx/query-54d124a54b2bb28f85b3ee9882f1e103d8e690ea0cb5189411834b9d8b246fc4.json
@@ -0,0 +1,23 @@
+{
+  "db_name": "PostgreSQL",
+  "query": "select pg_notify($1, $2)",
+  "describe": {
+    "columns": [
+      {
+        "ordinal": 0,
+        "name": "pg_notify",
+        "type_info": "Void"
+      }
+    ],
+    "parameters": {
+      "Left": [
+        "Text",
+        "Text"
+      ]
+    },
+    "nullable": [
+      null
+    ]
+  },
+  "hash": "54d124a54b2bb28f85b3ee9882f1e103d8e690ea0cb5189411834b9d8b246fc4"
+}
diff --git a/Cargo.toml b/Cargo.toml
index 80f3ee5..c9b1a15 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ tokio = { version = "1.40.0", features = [
 tracing = { version = "0.1.40", features = ["log"] }
 ulid = { version = "1.1.3", features = ["uuid"] }
 uuid = { version = "1.10.0", features = ["v4"] }
+num_cpus = "1.16.0"
 
 [dev-dependencies]
 futures = "0.3.30"
diff --git a/src/queue.rs b/src/queue.rs
index 21ab344..0865212 100644
--- a/src/queue.rs
+++ b/src/queue.rs
@@ -959,6 +959,28 @@ impl QueueBuilder {
     }
 }
 
+pub(crate) const SHUTDOWN_CHANNEL: &str = "underway_shutdown";
+
+/// Initiates a graceful shutdown by sending a `NOTIFY` to the
+/// `underway_shutdown` channel via the `pg_notify` function.
+///
+/// Workers listen on this channel; when a message is received, they stop
+/// processing new tasks and wait for any in-progress tasks to finish or time
+/// out.
+///
+/// This can be useful when combined with [`tokio::signal`] to ensure workers
+/// are stopped cleanly when your application shuts down.
+pub async fn graceful_shutdown<'a, E>(executor: E) -> Result
+where
+    E: PgExecutor<'a>,
+{
+    sqlx::query!("select pg_notify($1, $2)", SHUTDOWN_CHANNEL, "")
+        .execute(executor)
+        .await?;
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::HashSet;
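Reviewer's note: to make the intended usage concrete, below is a minimal, illustrative sketch of wiring `graceful_shutdown` up to `tokio::signal`, as the doc comments above and below suggest. The task type, queue name, `DATABASE_URL` lookup, and the exact public import paths are assumptions made for the example, not part of this change; the new test at the bottom of this diff exercises the same flow against the crate-internal paths.

```rust
use std::env;

use sqlx::PgPool;
use tokio::signal;
use underway::{queue::graceful_shutdown, task::Result as TaskResult, Queue, Task, Worker};

#[derive(Debug, Clone)]
struct ExampleTask;

impl Task for ExampleTask {
    type Input = ();

    async fn execute(&self, _: Self::Input) -> TaskResult {
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let pool = PgPool::connect(&env::var("DATABASE_URL")?).await?;

    let queue = Queue::builder()
        .name("example_queue")
        .pool(pool.clone())
        .build()
        .await?;

    // Poll for tasks in the background until a shutdown notification arrives.
    let worker = Worker::new(queue, ExampleTask);
    let worker_handle = tokio::spawn(async move { worker.run().await });

    // On Ctrl-C, ask all workers to stop taking new tasks and to wait for
    // in-progress tasks to finish (or hit their timeout).
    signal::ctrl_c().await?;
    graceful_shutdown(&pool).await?;

    if let Err(err) = worker_handle.await? {
        eprintln!("worker stopped with an error: {err}");
    }

    Ok(())
}
```

Note that the executor handed to `graceful_shutdown` only needs to reach the same database the workers are connected to, since the signal travels over `pg_notify`.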
diff --git a/src/worker.rs b/src/worker.rs
index 71e1b05..dfb7d2d 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -101,24 +101,63 @@
 //! # });
 //! # }
 //! ```
+//!
+//! # Stopping workers safely
+//!
+//! To ensure that workers aren't interrupted while handling in-progress
+//! tasks, the [`graceful_shutdown`](crate::queue::graceful_shutdown) function
+//! is provided.
+//!
+//! This function allows you to politely ask all workers to stop processing new
+//! tasks. At the same time, workers are also aware of any in-progress tasks
+//! they're working on and will wait for these to be done or time out.
+//!
+//! For cases where it's unimportant to wait for tasks to complete, this routine
+//! can be ignored.
+
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use jiff::{Span, ToSpan};
 use serde::{de::DeserializeOwned, Serialize};
-use sqlx::{postgres::types::PgInterval, PgConnection};
+use sqlx::{
+    postgres::{types::PgInterval, PgListener},
+    PgConnection,
+};
+use tokio::{sync::Semaphore, task::JoinSet};
 use tracing::instrument;
 
 use crate::{
     job::Job,
-    queue::{Error as QueueError, Queue},
+    queue::{Error as QueueError, Queue, SHUTDOWN_CHANNEL},
     task::{DequeuedTask, Error as TaskError, Id as TaskId, RetryCount, RetryPolicy, Task},
 };
 
 pub(crate) type Result = std::result::Result<(), Error>;
 
 /// A worker that's generic over the task it processes.
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct Worker<T: Task> {
     queue: Queue<T>,
-    task: T,
+    task: Arc<T>,
+
+    // Limits the number of `Task::execute` invocations this worker may run concurrently.
+    concurrency_limit: usize,
+
+    // Indicates the underlying queue has received a shutdown signal.
+    queue_shutdown: Arc<AtomicBool>,
+}
+
+impl<T: Task> Clone for Worker<T> {
+    fn clone(&self) -> Self {
+        Self {
+            queue: self.queue.clone(),
+            task: Arc::clone(&self.task),
+            concurrency_limit: self.concurrency_limit,
+            queue_shutdown: self.queue_shutdown.clone(),
+        }
+    }
 }
 
 /// Worker errors.
@@ -153,7 +192,9 @@ where
     fn from(job: Job<I>) -> Self {
         Self {
             queue: job.queue.clone(),
-            task: job,
+            task: Arc::new(job),
+            concurrency_limit: num_cpus::get(),
+            queue_shutdown: Arc::new(false.into()),
         }
     }
 }
@@ -165,15 +206,28 @@ where
     fn from(job: &Job<I>) -> Self {
         Self {
             queue: job.queue.clone(),
-            task: job.clone(),
+            task: Arc::new(job.to_owned()),
+            concurrency_limit: num_cpus::get(),
+            queue_shutdown: Arc::new(false.into()),
         }
     }
 }
 
-impl<T: Task> Worker<T> {
+impl<T: Task + Sync> Worker<T> {
     /// Creates a new worker with the given queue and task.
-    pub const fn new(queue: Queue<T>, task: T) -> Self {
-        Self { queue, task }
+    pub fn new(queue: Queue<T>, task: T) -> Self {
+        Self {
+            queue,
+            task: Arc::new(task),
+            concurrency_limit: num_cpus::get(),
+            queue_shutdown: Arc::new(false.into()),
+        }
+    }
+
+    /// Sets the concurrency limit for this worker (defaults to the number of available CPUs).
+    pub fn concurrency_limit(mut self, concurrency_limit: usize) -> Self {
+        self.concurrency_limit = concurrency_limit;
+        self
     }
 
     /// Runs the worker, processing tasks as they become available.
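As a usage note for the new builder-style setter, a short sketch, reusing the `queue`, `ExampleTask`, and error handling assumed in the sketch above:

```rust
// Cap this worker at four concurrent `Task::execute` invocations; without
// this call the limit defaults to the CPU count from `num_cpus::get()`.
let worker = Worker::new(queue, ExampleTask).concurrency_limit(4);
worker.run().await?;
```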
@@ -188,11 +242,80 @@ impl<T: Task> Worker<T> {
     /// polls.
     pub async fn run_every(&self, span: Span) -> Result {
         let mut interval = tokio::time::interval(span.try_into()?);
-        interval.tick().await;
+
+        // Set up a listener for shutdown notifications
+        let mut shutdown_listener = PgListener::connect_with(&self.queue.pool).await?;
+        shutdown_listener.listen(SHUTDOWN_CHANNEL).await?;
+
+        let concurrency_limit = Arc::new(Semaphore::new(self.concurrency_limit));
+        let mut processing_tasks = JoinSet::new();
+
         loop {
-            self.process_next_task().await?;
-            interval.tick().await;
+            tokio::select! {
+                shutdown_notif = shutdown_listener.recv() => {
+                    if let Err(err) = shutdown_notif {
+                        tracing::error!(%err, "NOTIFY resulted in an error");
+                        continue;
+                    }
+
+                    self.queue_shutdown.store(true, Ordering::SeqCst);
+
+                    let task_timeout = self.task.timeout();
+
+                    tracing::info!(
+                        task.timeout = ?task_timeout,
+                        "Waiting for in-progress tasks to finish or time out"
+                    );
+
+                    // Try to join all the processing tasks before the task timeout.
+                    let shutdown_result = tokio::time::timeout(
+                        task_timeout.try_into()?,
+                        async {
+                            while let Some(res) = processing_tasks.join_next().await {
+                                if let Err(err) = res {
+                                    tracing::error!(%err, "A processing task failed during shutdown");
+                                }
+                            }
+                        }
+                    ).await;
+
+                    match shutdown_result {
+                        Ok(_) => {
+                            tracing::debug!("All processing tasks completed gracefully");
+                        },
+                        Err(_) => {
+                            let remaining_tasks = processing_tasks.len();
+                            tracing::warn!(remaining_tasks, "Reached task timeout before all tasks completed");
+                        },
+                    }
+
+                    break;
+                },
+
+                _ = interval.tick() => {
+                    if self.queue_shutdown.load(Ordering::SeqCst) {
+                        tracing::info!("Queue is shut down so no new tasks will be processed");
+                        break;
+                    }
+
+                    let permit = concurrency_limit.clone().acquire_owned().await.expect("Concurrency limit semaphore should be open");
+                    processing_tasks.spawn({
+                        // TODO: Rather than clone the worker, we could have a separate type that
+                        // owns task processing.
+                        let worker = self.clone();
+
+                        async move {
+                            if let Err(err) = worker.process_next_task().await {
+                                tracing::error!(%err, "Error processing next task");
+                            }
+                            drop(permit);
+                        }
+                    });
+                }
+            }
+        }
+
+        Ok(())
     }
 
     /// Processes the next available task in the queue.
@@ -234,6 +357,7 @@ impl<T: Task> Worker<T> {
             Ok(_) => {
                 self.queue.mark_task_succeeded(&mut *tx, task_id).await?;
             }
+
             Err(err) => {
                 self.handle_task_error(err, &mut tx, task_id, task_row)
                     .await?;
@@ -357,13 +481,16 @@ pub(crate) fn pg_interval_to_span(
 
 #[cfg(test)]
 mod tests {
-    use std::sync::Arc;
+    use std::{sync::Arc, time::Duration as StdDuration};
 
     use sqlx::PgPool;
     use tokio::sync::Mutex;
 
     use super::*;
-    use crate::task::{Result as TaskResult, State as TaskState};
+    use crate::{
+        queue::graceful_shutdown,
+        task::{Result as TaskResult, State as TaskState},
+    };
 
     struct TestTask;
 
@@ -461,4 +588,92 @@
         Ok(())
     }
+
+    #[sqlx::test]
+    async fn test_graceful_shutdown(pool: PgPool) -> sqlx::Result<(), Error> {
+        let queue = Queue::builder()
+            .name("test_queue")
+            .pool(pool.clone())
+            .build()
+            .await?;
+
+        #[derive(Debug, Clone)]
+        struct LongRunningTask;
+
+        impl Task for LongRunningTask {
+            type Input = ();
+
+            async fn execute(&self, _: Self::Input) -> TaskResult {
+                tokio::time::sleep(StdDuration::from_secs(1)).await;
+                Ok(())
+            }
+        }
+
+        // Enqueue some tasks
+        for _ in 0..5 {
+            queue.enqueue(&pool, &LongRunningTask, ()).await?;
+        }
+
+        // Start workers
+        let worker = Worker::new(queue.clone(), LongRunningTask);
+        for _ in 0..2 {
+            let worker = worker.clone();
+            tokio::spawn(async move { worker.run().await });
+        }
+
+        let pending = sqlx::query_scalar!(
+            r#"
+            select count(*)
+            from underway.task
+            where state = $1
+            "#,
+            TaskState::Pending as _
+        )
+        .fetch_one(&pool)
+        .await?;
+        assert_eq!(pending, Some(5));
+
+        // Wait briefly to ensure workers are listening
+        tokio::time::sleep(StdDuration::from_secs(2)).await;
+
+        // Initiate graceful shutdown
+        graceful_shutdown(&pool).await?;
+
+        // Wait for tasks to be done
+        tokio::time::sleep(StdDuration::from_secs(5)).await;
+
+        let succeeded = sqlx::query_scalar!(
+            r#"
+            select count(*)
+            from underway.task
+            where state = $1
+            "#,
+            TaskState::Succeeded as _
+        )
+        .fetch_one(&pool)
+        .await?;
+        assert_eq!(succeeded, Some(5));
+
+        // New tasks shouldn't be processed
+        queue.enqueue(&pool, &LongRunningTask, ()).await?;
+
+        // Wait to ensure a worker would have seen the new task if one were processing
+        tokio::time::sleep(StdDuration::from_secs(5)).await;
+
+        let succeeded = sqlx::query_scalar!(
+            r#"
+            select count(*)
+            from underway.task
+            where state = $1
+            "#,
+            TaskState::Succeeded as _
+        )
+        .fetch_one(&pool)
+        .await?;
+
+        // Succeeded count remains the same since workers have been shut down
+        assert_eq!(succeeded, Some(5));
+
+        Ok(())
+    }
 }
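One last operational note: since `graceful_shutdown` is a thin wrapper over `pg_notify` on the `underway_shutdown` channel, the same shutdown can be initiated from any connection to the underlying database, not just from application code. A sketch using sqlx directly, assuming a `PgPool` named `pool` that points at the same database the workers use:

```rust
// Equivalent to calling `graceful_shutdown(&pool)`: notify the shutdown
// channel the workers are listening on, with an empty payload.
sqlx::query("select pg_notify('underway_shutdown', '')")
    .execute(&pool)
    .await?;
```

The same statement can be issued from `psql`, which may be handy when draining workers during a deploy.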