From e48a4eab99d05813c39b7a3a09ae60ed6c37bd5e Mon Sep 17 00:00:00 2001 From: Max Countryman Date: Fri, 25 Oct 2024 18:21:17 -0700 Subject: [PATCH] refactor scheduler This refactors the scheduler to be more like the worker: e.g. it now manages a cancellation token and uses it to manage the shutdown state. Again similarly to the worker, it listens on the same shutdown channel and upon a notify from postgres will cancel its cancellation token thereby stopping the scheduler. Coupled with these changes is a breaking change to the queue interface: instead of taking ownership of input, enqueue takes a reference. A similar change should be made for the schedule methods, but that is left out of this patch intentionally. --- examples/basic/src/main.rs | 2 +- examples/multitask/src/main.rs | 4 +- examples/step/src/main.rs | 2 +- examples/tracing/src/main.rs | 2 +- src/job.rs | 59 +++++++++++-------------- src/lib.rs | 4 +- src/queue.rs | 32 +++++++------- src/scheduler.rs | 79 +++++++++++++++++++++++++--------- src/worker.rs | 8 ++-- 9 files changed, 111 insertions(+), 81 deletions(-) diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index bbb7db3..41b372d 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -43,7 +43,7 @@ async fn main() -> Result<(), Box> { // Enqueue a job task. let task_id = job - .enqueue(WelcomeEmail { + .enqueue(&WelcomeEmail { user_id: 42, email: "ferris@example.com".to_string(), name: "Ferris".to_string(), diff --git a/examples/multitask/src/main.rs b/examples/multitask/src/main.rs index e7d7ab2..867009b 100644 --- a/examples/multitask/src/main.rs +++ b/examples/multitask/src/main.rs @@ -30,7 +30,7 @@ impl WelcomeEmailTask { // This ensures our task-specific configuration is applied. let welcome_email_task = self.into(); queue - .enqueue(pool, &welcome_email_task, TaskInput::WelcomeEmail(input)) + .enqueue(pool, &welcome_email_task, &TaskInput::WelcomeEmail(input)) .await } } @@ -63,7 +63,7 @@ impl OrderTask { // This ensures our task-specific configuration is applied. let order_task = self.into(); queue - .enqueue(pool, &order_task, TaskInput::Order(input)) + .enqueue(pool, &order_task, &TaskInput::Order(input)) .await } } diff --git a/examples/step/src/main.rs b/examples/step/src/main.rs index 714b6ce..fd5e557 100644 --- a/examples/step/src/main.rs +++ b/examples/step/src/main.rs @@ -65,7 +65,7 @@ async fn main() -> Result<(), Box> { .await?; // Enqueue the first step. - job.enqueue(Start { n: 42 }).await?; + job.enqueue(&Start { n: 42 }).await?; // Run the job worker. job.run().await?; diff --git a/examples/tracing/src/main.rs b/examples/tracing/src/main.rs index 7fdc754..b4c93d6 100644 --- a/examples/tracing/src/main.rs +++ b/examples/tracing/src/main.rs @@ -52,7 +52,7 @@ async fn main() -> Result<(), Box> { // Enqueue a job task. let task_id = job - .enqueue(WelcomeEmail { + .enqueue(&WelcomeEmail { user_id: 42, email: "ferris@example.com".to_string(), name: "Ferris".to_string(), diff --git a/src/job.rs b/src/job.rs index 09c72c8..c7466a1 100644 --- a/src/job.rs +++ b/src/job.rs @@ -346,7 +346,7 @@ //! .await?; //! //! // Enqueue a new job with the given input `()`. -//! job.enqueue(()).await?; +//! job.enqueue(&()).await?; //! # Ok::<(), Box>(()) //! # }); //! # } @@ -381,7 +381,7 @@ //! .queue(queue) //! .build(); //! -//! job.enqueue(()).await?; +//! job.enqueue(&()).await?; //! # Ok::<(), Box>(()) //! # }); //! # } @@ -419,7 +419,7 @@ //! .await?; //! //! // Enqueue a new job with a slightly more interesting value. -//! job.enqueue(Input { +//! job.enqueue(&Input { //! bucket_name: "my_bucket".to_string(), //! }) //! .await?; @@ -470,7 +470,7 @@ //! # //! //! // Enqueue using a transaction that we supply. -//! job.enqueue_using(&mut *tx, ()).await?; +//! job.enqueue_using(&mut *tx, &()).await?; //! //! # /* //! /* ...And more intervening logic involving `tx`. */ @@ -637,7 +637,7 @@ use jiff::Span; use sealed::JobState; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use sqlx::{PgExecutor, PgPool, Postgres, Transaction}; -use tokio::task::JoinHandle; +use tokio::task::{JoinHandle, JoinSet}; use tracing::instrument; use ulid::Ulid; @@ -744,7 +744,7 @@ where } /// Enqueue the job using a connection from the queue's pool. - pub async fn enqueue(&self, input: I) -> Result { + pub async fn enqueue(&self, input: &I) -> Result { let mut conn = self.queue.pool.acquire().await?; self.enqueue_using(&mut *conn, input).await } @@ -757,7 +757,7 @@ where /// **Note:** If you pass a transactional executor and the transaction is /// rolled back, the returned task ID will not correspond to any persisted /// task. - pub async fn enqueue_using<'a, E>(&self, executor: E, input: I) -> Result + pub async fn enqueue_using<'a, E>(&self, executor: E, input: &I) -> Result where E: PgExecutor<'a>, { @@ -768,7 +768,7 @@ where /// queue's pool /// /// The given delay is added to the task's configured delay, if one is set. - pub async fn enqueue_after<'a, E>(&self, input: I, delay: Span) -> Result + pub async fn enqueue_after<'a, E>(&self, input: &I, delay: Span) -> Result where E: PgExecutor<'a>, { @@ -789,7 +789,7 @@ where pub async fn enqueue_after_using<'a, E>( &self, executor: E, - input: I, + input: &I, delay: Span, ) -> Result where @@ -798,7 +798,7 @@ where let job_input = self.first_job_input(input)?; let id = self .queue - .enqueue_after(executor, self, job_input, delay) + .enqueue_after(executor, self, &job_input, delay) .await?; Ok(id) @@ -823,7 +823,7 @@ where where E: PgExecutor<'a>, { - let job_input = self.first_job_input(input)?; + let job_input = self.first_job_input(&input)?; self.queue .schedule(executor, zoned_schedule, job_input) .await?; @@ -855,23 +855,16 @@ where let worker = Worker::new(queue.clone(), job.clone()); let scheduler = Scheduler::new(queue, job); - let worker_task = tokio::spawn(async move { worker.run().await }); - let scheduler_task = tokio::spawn(async move { scheduler.run().await }); - - tokio::select! { - res = worker_task => { - match res { - Ok(inner_res) => inner_res?, - Err(join_err) => return Err(Error::from(join_err)), - } - }, + let mut workers = JoinSet::new(); + workers.spawn(async move { worker.run().await.map_err(Error::from) }); + workers.spawn(async move { scheduler.run().await.map_err(Error::from) }); - res = scheduler_task => { - match res { - Ok(inner_res) => inner_res?, - Err(join_err) => return Err(Error::from(join_err)), - } - }, + while let Some(ret) = workers.join_next().await { + match ret { + Ok(Err(err)) => return Err(err), + Err(err) => return Err(Error::from(err)), + _ => continue, + } } Ok(()) @@ -885,7 +878,7 @@ where // TODO: stop method - fn first_job_input(&self, input: I) -> Result { + fn first_job_input(&self, input: &I) -> Result { let step_input = serde_json::to_value(input)?; let step_index = self.current_index.load(Ordering::SeqCst); let job_id = Ulid::new().into(); @@ -970,7 +963,7 @@ where }; self.queue - .enqueue_after(&mut *tx, self, next_job_input, delay) + .enqueue_after(&mut *tx, self, &next_job_input, delay) .await .map_err(|err| TaskError::Retryable(err.to_string()))?; @@ -1516,7 +1509,7 @@ mod tests { .build() .await?; - job.enqueue(()).await?; + job.enqueue(&()).await?; job.start(); @@ -1596,7 +1589,7 @@ mod tests { let input = Input { message: "Hello, world!".to_string(), }; - let task_id = job.enqueue(input.clone()).await?; + let task_id = job.enqueue(&input).await?; let Some(dequeued_task) = queue.dequeue(&pool).await? else { panic!("Task should exist"); @@ -1726,7 +1719,7 @@ mod tests { let input = Step1 { message: "Hello, world!".to_string(), }; - let task_id = job.enqueue(input).await?; + let task_id = job.enqueue(&input).await?; // Dequeue the first task. let Some(dequeued_task) = queue.dequeue(&pool).await? else { @@ -1846,7 +1839,7 @@ mod tests { let input = Step1 { message: "Hello, world!".to_string(), }; - let task_id = job.enqueue(input.clone()).await?; + let task_id = job.enqueue(&input).await?; // Dequeue the first task. let Some(dequeued_task) = queue.dequeue(&pool).await? else { diff --git a/src/lib.rs b/src/lib.rs index 99ddabb..a7acf26 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -86,7 +86,7 @@ //! .await?; //! //! // Here we enqueue a new job to be processed later. -//! job.enqueue(WelcomeEmail { +//! job.enqueue(&WelcomeEmail { //! user_id: 42, //! email: "ferris@example.com".to_string(), //! name: "Ferris".to_string(), @@ -161,7 +161,7 @@ //! .await?; //! //! // Enqueue the job for the given order. -//! job.enqueue(GenerateReceipt { order_id: 42 }).await?; +//! job.enqueue(&GenerateReceipt { order_id: 42 }).await?; //! //! // Start processing enqueued jobs. //! job.start().await??; diff --git a/src/queue.rs b/src/queue.rs index a39839e..f7b3ffb 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -49,7 +49,7 @@ //! .pool(pool.clone()) //! .build() //! .await?; -//! # queue.enqueue(&pool, &MyTask, ()).await?; +//! # queue.enqueue(&pool, &MyTask, &()).await?; //! # Ok::<(), Box>(()) //! # }); //! # } @@ -104,7 +104,7 @@ //! # let task = MyTask; //! //! // Enqueue the task. -//! queue.enqueue(&pool, &task, ()).await?; +//! queue.enqueue(&pool, &task, &()).await?; //! //! if let Some(task) = queue.dequeue(&pool).await? { //! // Process the task here @@ -326,7 +326,7 @@ impl Queue { /// `std::time::Duration`. /// /// [ULID]: https://github.com/ulid/spec?tab=readme-ov-file#specification - pub async fn enqueue<'a, E>(&self, executor: E, task: &T, input: T::Input) -> Result + pub async fn enqueue<'a, E>(&self, executor: E, task: &T, input: &T::Input) -> Result where E: PgExecutor<'a>, { @@ -345,7 +345,7 @@ impl Queue { &self, executor: E, task: &T, - input: T::Input, + input: &T::Input, delay: Span, ) -> Result where @@ -368,7 +368,7 @@ impl Queue { &self, executor: E, task: &T, - input: T::Input, + input: &T::Input, delay: Span, ) -> Result where @@ -1061,7 +1061,7 @@ mod tests { let input = serde_json::json!({ "key": "value" }); let task = TestTask; - let task_id = queue.enqueue(&pool, &task, input.clone()).await?; + let task_id = queue.enqueue(&pool, &task, &input).await?; // Query the database to verify the task was enqueued let dequeued_task = sqlx::query!( @@ -1115,7 +1115,7 @@ mod tests { let task = TestTask; let task_id = queue - .enqueue_after(&pool, &task, input.clone(), 5.minutes()) + .enqueue_after(&pool, &task, &input, 5.minutes()) .await?; // Check the delay @@ -1149,7 +1149,7 @@ mod tests { let task = TestTask; // Enqueue a task - let task_id = queue.enqueue(&pool, &task, input.clone()).await?; + let task_id = queue.enqueue(&pool, &task, &input).await?; // Dequeue the task let dequeued_task = queue.dequeue(&pool).await?; @@ -1210,7 +1210,7 @@ mod tests { // Enqueue multiple tasks for _ in 0..5 { - queue.enqueue(&pool, &task, input.clone()).await?; + queue.enqueue(&pool, &task, &input).await?; } // Simulate concurrent dequeues @@ -1267,7 +1267,7 @@ mod tests { let task = TestTask; // Enqueue a task - let task_id = queue.enqueue(&pool, &task, input).await?; + let task_id = queue.enqueue(&pool, &task, &input).await?; // Mark the task as in progress queue.mark_task_in_progress(&pool, task_id).await?; @@ -1299,7 +1299,7 @@ mod tests { let task = TestTask; // Enqueue a task - let task_id = queue.enqueue(&pool, &task, input).await?; + let task_id = queue.enqueue(&pool, &task, &input).await?; // Reschedule the task for retry let retry_count = 1; @@ -1348,7 +1348,7 @@ mod tests { let task = TestTask; // Enqueue a task - let task_id = queue.enqueue(&pool, &task, input).await?; + let task_id = queue.enqueue(&pool, &task, &input).await?; // Cancel the task queue.mark_task_cancelled(&pool, task_id).await?; @@ -1380,7 +1380,7 @@ mod tests { let task = TestTask; // Enqueue a task - let task_id = queue.enqueue(&pool, &task, input).await?; + let task_id = queue.enqueue(&pool, &task, &input).await?; // Mark the task as succeeded queue.mark_task_succeeded(&pool, task_id).await?; @@ -1466,7 +1466,7 @@ mod tests { let task = TestTask; // Enqueue a task - let task_id = queue.enqueue(&pool, &task, input).await?; + let task_id = queue.enqueue(&pool, &task, &input).await?; // Mark the task as failed queue.mark_task_failed(&pool, task_id).await?; @@ -1499,7 +1499,7 @@ mod tests { let task = TestTask; // Enqueue a task - let task_id = queue.enqueue(&pool, &task, input).await?; + let task_id = queue.enqueue(&pool, &task, &input).await?; // Update task failure details let retry_count = 2; @@ -1540,7 +1540,7 @@ mod tests { let task = TestTask; // Enqueue a task - let task_id = queue.enqueue(&pool, &task, input).await?; + let task_id = queue.enqueue(&pool, &task, &input).await?; // Move the task to DLQ queue.move_task_to_dlq(&pool, task_id, "test_dlq").await?; diff --git a/src/scheduler.rs b/src/scheduler.rs index 819ad37..d55d882 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -1,14 +1,13 @@ -use std::{ - future, result::Result as StdResult, str::FromStr, sync::Arc, time::Duration as StdDuration, -}; +use std::{result::Result as StdResult, str::FromStr, sync::Arc, time::Duration as StdDuration}; use jiff::{tz::TimeZone, Span, ToSpan, Zoned}; use jiff_cron::Schedule; -use sqlx::postgres::PgAdvisoryLock; +use sqlx::postgres::{PgAdvisoryLock, PgListener}; +use tokio_util::sync::CancellationToken; use tracing::instrument; use crate::{ - queue::{try_acquire_advisory_lock, Error as QueueError}, + queue::{try_acquire_advisory_lock, Error as QueueError, SHUTDOWN_CHANNEL}, Queue, Task, }; @@ -19,6 +18,9 @@ pub enum Error { #[error(transparent)] Queue(#[from] QueueError), + #[error(transparent)] + Database(#[from] sqlx::Error), + #[error(transparent)] Jiff(#[from] jiff::Error), @@ -44,6 +46,9 @@ pub struct Scheduler { queue: Queue, queue_lock: PgAdvisoryLock, task: Arc, + + // When this token is cancelled the queue has been shutdown. + shutdown_token: CancellationToken, } impl Scheduler { @@ -54,6 +59,7 @@ impl Scheduler { queue, queue_lock, task: Arc::new(task), + shutdown_token: CancellationToken::new(), } } @@ -65,29 +71,60 @@ impl Scheduler { /// Runs the scheduler in a loop, sleeping for the given period per /// iteration. pub async fn run_every(&self, period: Span) -> Result { - let conn = self.queue.pool.acquire().await.map_err(QueueError::from)?; + let conn = self.queue.pool.acquire().await?; let Some(_guard) = try_acquire_advisory_lock(conn, &self.queue_lock).await? else { - // We can't acquire the lock, so we'll return a future that waits forever. - return future::pending().await; + tracing::debug!("Scheduler could not acquire lock, exiting"); + return Ok(()); }; - let mut interval = tokio::time::interval(period.try_into()?); - interval.tick().await; + let Some((mut zoned_schedule, input)) = self.queue.task_schedule(&self.queue.pool).await? + else { + // No schedule configured, so we'll exit. + return Ok(()); + }; + + // Set up a listener for shutdown notifications + let mut shutdown_listener = PgListener::connect_with(&self.queue.pool).await?; + shutdown_listener.listen(SHUTDOWN_CHANNEL).await?; + + let mut polling_interval = tokio::time::interval(period.try_into()?); loop { - // TODO: It would be preferrable to not check the schedule every second and wait - // for a NOTIFY instead. - if let Some((zoned_schedule, input)) = - self.queue.task_schedule(&self.queue.pool).await? - { - // TODO: If we were waiting for a NOTIFY or timeout, we could keep processing - // the same schedule without fetching from the database. - if let Some(until_next) = zoned_schedule.into_iter().next() { - self.process_next_schedule(until_next, input).await? + tokio::select! { + notify_shutdown = shutdown_listener.recv() => { + match notify_shutdown { + Ok(_) => { + self.shutdown_token.cancel(); + }, + + Err(err) => { + tracing::error!(%err, "Postgres shutdown notification error"); + } + } + } + + _ = self.shutdown_token.cancelled() => { + break + } + + _ = polling_interval.tick() => { + self.trigger_schedule_processing(&mut zoned_schedule, &input).await? } } + } + + Ok(()) + } - interval.tick().await; + async fn trigger_schedule_processing( + &self, + zoned_schedule: &mut ZonedSchedule, + input: &T::Input, + ) -> Result { + if let Some(until_next) = zoned_schedule.into_iter().next() { + self.process_next_schedule(until_next, input).await?; } + + Ok(()) } #[instrument( @@ -99,7 +136,7 @@ impl Scheduler { ), err )] - async fn process_next_schedule(&self, until_next: StdDuration, input: T::Input) -> Result { + async fn process_next_schedule(&self, until_next: StdDuration, input: &T::Input) -> Result { tracing::debug!(?until_next, "Sleeping until the next scheduled enqueue"); tokio::time::sleep(until_next).await; diff --git a/src/worker.rs b/src/worker.rs index 4480989..fcf9621 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -595,7 +595,7 @@ mod tests { // Enqueue a task. let task = TestTask; - queue.enqueue(&pool, &task, ()).await?; + queue.enqueue(&pool, &task, &()).await?; assert!(queue.dequeue(&pool).await?.is_some()); // Process the task. @@ -623,7 +623,7 @@ mod tests { let worker = Worker::new(queue.clone(), task.clone()); // Enqueue the task - let task_id = queue.enqueue(&pool, &worker.task, ()).await?; + let task_id = queue.enqueue(&pool, &worker.task, &()).await?; // Process the task multiple times to simulate retries for retries in 0..3 { @@ -689,7 +689,7 @@ mod tests { // Enqueue some tasks now that the worker is listening for _ in 0..5 { - queue.enqueue(&pool, &LongRunningTask, ()).await?; + queue.enqueue(&pool, &LongRunningTask, &()).await?; } // Initiate graceful shutdown @@ -711,7 +711,7 @@ mod tests { assert_eq!(succeeded, Some(5)); // New tasks shouldn't be processed - queue.enqueue(&pool, &LongRunningTask, ()).await?; + queue.enqueue(&pool, &LongRunningTask, &()).await?; // Wait to ensure a worker would have seen the new task if one were processing tokio::time::sleep(StdDuration::from_secs(1)).await;