From 1e2da9e085adb04b6bcdaa44e1a53aae3ce3e972 Mon Sep 17 00:00:00 2001 From: Max Countryman Date: Fri, 25 Oct 2024 15:14:47 -0700 Subject: [PATCH] use cancellation token for shutdown This refactors worker shutdown to use a cancellation token to signal shutdown. By doing so, we can avoid managing an atomic bool. It also makes it possible to provide addtional methods, such as a `shutdown` method, on the worker itself to signal shutdown. --- Cargo.toml | 1 + examples/graceful_shutdown/src/main.rs | 22 ++++++++---- src/worker.rs | 49 +++++++++++--------------- 3 files changed, 38 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3583666..a8e2ff2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ tracing = { version = "0.1.40", features = ["log"] } ulid = { version = "1.1.3", features = ["uuid"] } uuid = { version = "1.10.0", features = ["v4", "serde"] } num_cpus = "1.16.0" +tokio-util = "0.7.12" [dev-dependencies] futures = "0.3.30" diff --git a/examples/graceful_shutdown/src/main.rs b/examples/graceful_shutdown/src/main.rs index dcca5f0..2526c8d 100644 --- a/examples/graceful_shutdown/src/main.rs +++ b/examples/graceful_shutdown/src/main.rs @@ -1,7 +1,7 @@ use std::env; -use sqlx::PgPool; -use tokio::signal; +use sqlx::{postgres::PgPoolOptions, PgPool}; +use tokio::{signal, task::JoinSet}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; use underway::{Job, To}; @@ -45,7 +45,10 @@ async fn main() -> Result<(), Box> { // Set up the database connection pool. let database_url = &env::var("DATABASE_URL").expect("DATABASE_URL should be set"); - let pool = PgPool::connect(database_url).await?; + let pool = PgPoolOptions::new() + .max_connections(25) + .connect(database_url) + .await?; // Run migrations. underway::MIGRATOR.run(&pool).await?; @@ -53,7 +56,7 @@ async fn main() -> Result<(), Box> { // Build the job. let job = Job::builder() .step(|_ctx, _input| async move { - let sleep_duration = std::time::Duration::from_secs(5); + let sleep_duration = std::time::Duration::from_secs(10); tracing::info!(?sleep_duration, "Hello from a long-running task"); @@ -73,8 +76,15 @@ async fn main() -> Result<(), Box> { // Await the shutdown signal handler in its own task. tokio::spawn(async move { shutdown_signal(&pool).await }); - // The job will run until the queue signals a shutdown. - job.run().await?; + // All jobs will run until the queue signals shutdown. + let mut jobs = JoinSet::new(); + for _ in 0..2 { + jobs.spawn({ + let job = job.clone(); + async move { job.run().await } + }); + } + jobs.join_all().await; Ok(()) } diff --git a/src/worker.rs b/src/worker.rs index 3b2a142..4480989 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -117,10 +117,7 @@ //! For cases where it's unimportant to wait for tasks to complete, this routine //! can be ignored. -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; +use std::sync::Arc; use jiff::{Span, ToSpan}; use serde::Deserialize; @@ -129,6 +126,7 @@ use sqlx::{ Acquire, PgConnection, }; use tokio::{sync::Semaphore, task::JoinSet}; +use tokio_util::sync::CancellationToken; use tracing::instrument; use crate::{ @@ -171,8 +169,8 @@ pub struct Worker { // Limits the number of concurrent `Task::execute` invocations this worker will be allowed. concurrency_limit: usize, - // Indicates that a queue shutdown signal has been received. - queue_shutdown: Arc, + // When this token is cancelled the queue has been shutdown. + shutdown_token: CancellationToken, } impl Clone for Worker { @@ -181,7 +179,7 @@ impl Clone for Worker { queue: self.queue.clone(), task: self.task.clone(), concurrency_limit: self.concurrency_limit, - queue_shutdown: self.queue_shutdown.clone(), + shutdown_token: self.shutdown_token.clone(), } } } @@ -193,7 +191,7 @@ impl Worker { queue, task: Arc::new(task), concurrency_limit: num_cpus::get(), - queue_shutdown: Arc::new(AtomicBool::new(false)), + shutdown_token: CancellationToken::new(), } } @@ -216,7 +214,7 @@ impl Worker { /// Same as `run` but allows for the configuration of the delay between /// polls. pub async fn run_every(&self, period: Span) -> Result { - let mut interval = tokio::time::interval(period.try_into()?); + let mut polling_interval = tokio::time::interval(period.try_into()?); // Set up a listener for shutdown notifications let mut shutdown_listener = PgListener::connect_with(&self.queue.pool).await?; @@ -234,35 +232,35 @@ impl Worker { notify_shutdown = shutdown_listener.recv() => { match notify_shutdown { Ok(_) => { - self.handle_shutdown(&mut processing_tasks).await?; - break + self.shutdown_token.cancel(); }, Err(err) => { tracing::error!(%err, "Postgres shutdown notification error"); - continue } } - }, + } + _ = self.shutdown_token.cancelled() => { + self.handle_shutdown(&mut processing_tasks).await?; + break + } + + // Listen for new pending tasks. notify_task_change = task_change_listener.recv() => { match notify_task_change { Ok(task_change) => self.handle_task_change(task_change, concurrency_limit.clone(), &mut processing_tasks).await?, Err(err) => { tracing::error!(%err, "Postgres task change notification error"); - continue; } }; - }, + } - // Polling fallback. - _ = interval.tick() => { - self.trigger_task_processing( - concurrency_limit.clone(), - &mut processing_tasks - ).await; + // Pending task polling fallback. + _ = polling_interval.tick() => { + self.trigger_task_processing(concurrency_limit.clone(), &mut processing_tasks).await; } } } @@ -271,8 +269,6 @@ impl Worker { } async fn handle_shutdown(&self, processing_tasks: &mut JoinSet<()>) -> Result { - self.queue_shutdown.store(true, Ordering::SeqCst); - let task_timeout = self.task.timeout(); tracing::info!( @@ -339,19 +335,16 @@ impl Worker { processing_tasks.spawn({ let worker = self.clone(); async move { - while !worker.queue_shutdown.load(Ordering::SeqCst) { + while !worker.shutdown_token.is_cancelled() { match worker.process_next_task().await { Err(err) => { tracing::error!(err = %err, "Error processing next task"); continue; } - Ok(Some(_)) => { - // Since we just processed a task, we'll try again in case there's - // more. + // Since we just processed a task, we'll try again in case there's more. continue; } - Ok(None) => { // We tried to process a task but found none so we'll stop trying. break;