use cancellation token for shutdown #32

Merged 1 commit on Oct 26, 2024
1 change: 1 addition & 0 deletions Cargo.toml
@@ -27,6 +27,7 @@ tracing = { version = "0.1.40", features = ["log"] }
ulid = { version = "1.1.3", features = ["uuid"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] }
num_cpus = "1.16.0"
tokio-util = "0.7.12"

[dev-dependencies]
futures = "0.3.30"
22 changes: 16 additions & 6 deletions examples/graceful_shutdown/src/main.rs
@@ -1,7 +1,7 @@
use std::env;

use sqlx::PgPool;
use tokio::signal;
use sqlx::{postgres::PgPoolOptions, PgPool};
use tokio::{signal, task::JoinSet};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use underway::{Job, To};

@@ -45,15 +45,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// Set up the database connection pool.
let database_url = &env::var("DATABASE_URL").expect("DATABASE_URL should be set");
let pool = PgPool::connect(database_url).await?;
let pool = PgPoolOptions::new()
.max_connections(25)
.connect(database_url)
.await?;

// Run migrations.
underway::MIGRATOR.run(&pool).await?;

// Build the job.
let job = Job::builder()
.step(|_ctx, _input| async move {
let sleep_duration = std::time::Duration::from_secs(5);
let sleep_duration = std::time::Duration::from_secs(10);

tracing::info!(?sleep_duration, "Hello from a long-running task");

@@ -73,8 +76,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Await the shutdown signal handler in its own task.
tokio::spawn(async move { shutdown_signal(&pool).await });

// The job will run until the queue signals a shutdown.
job.run().await?;
// All jobs will run until the queue signals shutdown.
let mut jobs = JoinSet::new();
for _ in 0..2 {
jobs.spawn({
let job = job.clone();
async move { job.run().await }
});
}
jobs.join_all().await;

Ok(())
}
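
The example now clones the job and drives several copies from a `JoinSet`, waiting on all of them with `join_all`. A minimal standalone sketch of that pattern, with plain sleeps standing in for `job.run()` (assumes a tokio release recent enough to ship `JoinSet::join_all`):

```rust
use std::time::Duration;

use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    let mut set = JoinSet::new();

    // Spawn a couple of concurrent tasks, mirroring the two cloned jobs above.
    for i in 0..2u64 {
        set.spawn(async move {
            // Stand-in for `job.run().await`.
            tokio::time::sleep(Duration::from_millis(10 * i)).await;
            i
        });
    }

    // `join_all` consumes the set and resolves once every task has finished,
    // returning their outputs.
    let results: Vec<u64> = set.join_all().await;
    assert_eq!(results.len(), 2);
}
```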
49 changes: 21 additions & 28 deletions src/worker.rs
@@ -117,10 +117,7 @@
//! For cases where it's unimportant to wait for tasks to complete, this routine
//! can be ignored.

use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::sync::Arc;

use jiff::{Span, ToSpan};
use serde::Deserialize;
@@ -129,6 +126,7 @@ use sqlx::{
Acquire, PgConnection,
};
use tokio::{sync::Semaphore, task::JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::instrument;

use crate::{
@@ -171,8 +169,8 @@ pub struct Worker<T: Task> {
// Limits the number of concurrent `Task::execute` invocations this worker will be allowed.
concurrency_limit: usize,

// Indicates that a queue shutdown signal has been received.
queue_shutdown: Arc<AtomicBool>,
// When this token is cancelled the queue has been shutdown.
shutdown_token: CancellationToken,
}

impl<T: Task> Clone for Worker<T> {
@@ -181,7 +179,7 @@ impl<T: Task> Clone for Worker<T> {
queue: self.queue.clone(),
task: self.task.clone(),
concurrency_limit: self.concurrency_limit,
queue_shutdown: self.queue_shutdown.clone(),
shutdown_token: self.shutdown_token.clone(),
}
}
}
@@ -193,7 +191,7 @@ impl<T: Task + Sync> Worker<T> {
queue,
task: Arc::new(task),
concurrency_limit: num_cpus::get(),
queue_shutdown: Arc::new(AtomicBool::new(false)),
shutdown_token: CancellationToken::new(),
}
}

@@ -216,7 +214,7 @@ impl<T: Task + Sync> Worker<T> {
/// Same as `run` but allows for the configuration of the delay between
/// polls.
pub async fn run_every(&self, period: Span) -> Result {
let mut interval = tokio::time::interval(period.try_into()?);
let mut polling_interval = tokio::time::interval(period.try_into()?);

// Set up a listener for shutdown notifications
let mut shutdown_listener = PgListener::connect_with(&self.queue.pool).await?;
@@ -234,35 +232,35 @@ impl<T: Task + Sync> Worker<T> {
notify_shutdown = shutdown_listener.recv() => {
match notify_shutdown {
Ok(_) => {
self.handle_shutdown(&mut processing_tasks).await?;
break
self.shutdown_token.cancel();
},

Err(err) => {
tracing::error!(%err, "Postgres shutdown notification error");
continue
}
}
},
}

_ = self.shutdown_token.cancelled() => {
self.handle_shutdown(&mut processing_tasks).await?;
break
}

// Listen for new pending tasks.
notify_task_change = task_change_listener.recv() => {
match notify_task_change {
Ok(task_change) => self.handle_task_change(task_change, concurrency_limit.clone(), &mut processing_tasks).await?,

Err(err) => {
tracing::error!(%err, "Postgres task change notification error");
continue;
}
};

},
}

// Polling fallback.
_ = interval.tick() => {
self.trigger_task_processing(
concurrency_limit.clone(),
&mut processing_tasks
).await;
// Pending task polling fallback.
_ = polling_interval.tick() => {
self.trigger_task_processing(concurrency_limit.clone(), &mut processing_tasks).await;
}
}
}
@@ -271,8 +269,6 @@ impl<T: Task + Sync> Worker<T> {
}

async fn handle_shutdown(&self, processing_tasks: &mut JoinSet<()>) -> Result {
self.queue_shutdown.store(true, Ordering::SeqCst);

let task_timeout = self.task.timeout();

tracing::info!(
@@ -339,19 +335,16 @@ impl<T: Task + Sync> Worker<T> {
processing_tasks.spawn({
let worker = self.clone();
async move {
while !worker.queue_shutdown.load(Ordering::SeqCst) {
while !worker.shutdown_token.is_cancelled() {
match worker.process_next_task().await {
Err(err) => {
tracing::error!(err = %err, "Error processing next task");
continue;
}

Ok(Some(_)) => {
// Since we just processed a task, we'll try again in case there's
// more.
// Since we just processed a task, we'll try again in case there's more.
continue;
}

Ok(None) => {
// We tried to process a task but found none so we'll stop trying.
break;
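
The shape `run_every` ends up with is a cancellation-token-driven `select!` loop: a Postgres shutdown notification cancels the token, and the cancelled token breaks the loop so in-flight tasks can be drained. A simplified, self-contained sketch of that control flow, where `poll_for_work` and the timings are placeholders rather than underway APIs:

```rust
use std::time::Duration;

use tokio_util::sync::CancellationToken;

// Placeholder for the worker's real polling step.
async fn poll_for_work() {
    tokio::time::sleep(Duration::from_millis(100)).await;
}

async fn run(shutdown_token: CancellationToken) {
    let mut polling_interval = tokio::time::interval(Duration::from_secs(1));

    loop {
        tokio::select! {
            // A cancelled token means shutdown was requested; break out of the
            // loop so the caller can wait on outstanding work.
            _ = shutdown_token.cancelled() => break,

            // Polling fallback, analogous to `polling_interval.tick()` above.
            _ = polling_interval.tick() => poll_for_work().await,
        }
    }
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let worker = tokio::spawn(run(token.clone()));

    // Elsewhere (e.g. a signal handler), cancelling the token stops the loop.
    tokio::time::sleep(Duration::from_millis(250)).await;
    token.cancel();
    worker.await.unwrap();
}
```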