Skip to content

Commit

Permalink
refactor scheduler
Browse files Browse the repository at this point in the history
This refactors the scheduler to be more like the worker: e.g. it now
manages a cancellation token and uses it to manage the shutdown state.

Again similarly to the worker, it listens on the same shutdown channel
and upon a notify from postgres will cancel its cancellation token
thereby stopping the scheduler.

Coupled with these changes is a breaking change to the queue interface:
instead of taking ownership of input, enqueue takes a reference. A
similar change should be made for the schedule methods, but that is left
out of this patch intentionally.
  • Loading branch information
maxcountryman committed Oct 26, 2024
1 parent 6127c9a commit e48a4ea
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 81 deletions.
2 changes: 1 addition & 1 deletion examples/basic/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// Enqueue a job task.
let task_id = job
.enqueue(WelcomeEmail {
.enqueue(&WelcomeEmail {
user_id: 42,
email: "[email protected]".to_string(),
name: "Ferris".to_string(),
Expand Down
4 changes: 2 additions & 2 deletions examples/multitask/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl WelcomeEmailTask {
// This ensures our task-specific configuration is applied.
let welcome_email_task = self.into();
queue
.enqueue(pool, &welcome_email_task, TaskInput::WelcomeEmail(input))
.enqueue(pool, &welcome_email_task, &TaskInput::WelcomeEmail(input))
.await
}
}
Expand Down Expand Up @@ -63,7 +63,7 @@ impl OrderTask {
// This ensures our task-specific configuration is applied.
let order_task = self.into();
queue
.enqueue(pool, &order_task, TaskInput::Order(input))
.enqueue(pool, &order_task, &TaskInput::Order(input))
.await
}
}
Expand Down
2 changes: 1 addition & 1 deletion examples/step/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await?;

// Enqueue the first step.
job.enqueue(Start { n: 42 }).await?;
job.enqueue(&Start { n: 42 }).await?;

// Run the job worker.
job.run().await?;
Expand Down
2 changes: 1 addition & 1 deletion examples/tracing/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// Enqueue a job task.
let task_id = job
.enqueue(WelcomeEmail {
.enqueue(&WelcomeEmail {
user_id: 42,
email: "[email protected]".to_string(),
name: "Ferris".to_string(),
Expand Down
59 changes: 26 additions & 33 deletions src/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@
//! .await?;
//!
//! // Enqueue a new job with the given input `()`.
//! job.enqueue(()).await?;
//! job.enqueue(&()).await?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! # });
//! # }
Expand Down Expand Up @@ -381,7 +381,7 @@
//! .queue(queue)
//! .build();
//!
//! job.enqueue(()).await?;
//! job.enqueue(&()).await?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! # });
//! # }
Expand Down Expand Up @@ -419,7 +419,7 @@
//! .await?;
//!
//! // Enqueue a new job with a slightly more interesting value.
//! job.enqueue(Input {
//! job.enqueue(&Input {
//! bucket_name: "my_bucket".to_string(),
//! })
//! .await?;
Expand Down Expand Up @@ -470,7 +470,7 @@
//! #
//!
//! // Enqueue using a transaction that we supply.
//! job.enqueue_using(&mut *tx, ()).await?;
//! job.enqueue_using(&mut *tx, &()).await?;
//!
//! # /*
//! /* ...And more intervening logic involving `tx`. */
Expand Down Expand Up @@ -637,7 +637,7 @@ use jiff::Span;
use sealed::JobState;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sqlx::{PgExecutor, PgPool, Postgres, Transaction};
use tokio::task::JoinHandle;
use tokio::task::{JoinHandle, JoinSet};
use tracing::instrument;
use ulid::Ulid;

Expand Down Expand Up @@ -744,7 +744,7 @@ where
}

/// Enqueue the job using a connection from the queue's pool.
pub async fn enqueue(&self, input: I) -> Result<TaskId> {
pub async fn enqueue(&self, input: &I) -> Result<TaskId> {
let mut conn = self.queue.pool.acquire().await?;
self.enqueue_using(&mut *conn, input).await
}
Expand All @@ -757,7 +757,7 @@ where
/// **Note:** If you pass a transactional executor and the transaction is
/// rolled back, the returned task ID will not correspond to any persisted
/// task.
pub async fn enqueue_using<'a, E>(&self, executor: E, input: I) -> Result<TaskId>
pub async fn enqueue_using<'a, E>(&self, executor: E, input: &I) -> Result<TaskId>
where
E: PgExecutor<'a>,
{
Expand All @@ -768,7 +768,7 @@ where
/// queue's pool
///
/// The given delay is added to the task's configured delay, if one is set.
pub async fn enqueue_after<'a, E>(&self, input: I, delay: Span) -> Result<TaskId>
pub async fn enqueue_after<'a, E>(&self, input: &I, delay: Span) -> Result<TaskId>
where
E: PgExecutor<'a>,
{
Expand All @@ -789,7 +789,7 @@ where
pub async fn enqueue_after_using<'a, E>(
&self,
executor: E,
input: I,
input: &I,
delay: Span,
) -> Result<TaskId>
where
Expand All @@ -798,7 +798,7 @@ where
let job_input = self.first_job_input(input)?;
let id = self
.queue
.enqueue_after(executor, self, job_input, delay)
.enqueue_after(executor, self, &job_input, delay)
.await?;

Ok(id)
Expand All @@ -823,7 +823,7 @@ where
where
E: PgExecutor<'a>,
{
let job_input = self.first_job_input(input)?;
let job_input = self.first_job_input(&input)?;
self.queue
.schedule(executor, zoned_schedule, job_input)
.await?;
Expand Down Expand Up @@ -855,23 +855,16 @@ where
let worker = Worker::new(queue.clone(), job.clone());
let scheduler = Scheduler::new(queue, job);

let worker_task = tokio::spawn(async move { worker.run().await });
let scheduler_task = tokio::spawn(async move { scheduler.run().await });

tokio::select! {
res = worker_task => {
match res {
Ok(inner_res) => inner_res?,
Err(join_err) => return Err(Error::from(join_err)),
}
},
let mut workers = JoinSet::new();
workers.spawn(async move { worker.run().await.map_err(Error::from) });
workers.spawn(async move { scheduler.run().await.map_err(Error::from) });

res = scheduler_task => {
match res {
Ok(inner_res) => inner_res?,
Err(join_err) => return Err(Error::from(join_err)),
}
},
while let Some(ret) = workers.join_next().await {
match ret {
Ok(Err(err)) => return Err(err),
Err(err) => return Err(Error::from(err)),
_ => continue,
}
}

Ok(())
Expand All @@ -885,7 +878,7 @@ where

// TODO: stop method

fn first_job_input(&self, input: I) -> Result<JobState> {
fn first_job_input(&self, input: &I) -> Result<JobState> {
let step_input = serde_json::to_value(input)?;
let step_index = self.current_index.load(Ordering::SeqCst);
let job_id = Ulid::new().into();
Expand Down Expand Up @@ -970,7 +963,7 @@ where
};

self.queue
.enqueue_after(&mut *tx, self, next_job_input, delay)
.enqueue_after(&mut *tx, self, &next_job_input, delay)
.await
.map_err(|err| TaskError::Retryable(err.to_string()))?;

Expand Down Expand Up @@ -1516,7 +1509,7 @@ mod tests {
.build()
.await?;

job.enqueue(()).await?;
job.enqueue(&()).await?;

job.start();

Expand Down Expand Up @@ -1596,7 +1589,7 @@ mod tests {
let input = Input {
message: "Hello, world!".to_string(),
};
let task_id = job.enqueue(input.clone()).await?;
let task_id = job.enqueue(&input).await?;

let Some(dequeued_task) = queue.dequeue(&pool).await? else {
panic!("Task should exist");
Expand Down Expand Up @@ -1726,7 +1719,7 @@ mod tests {
let input = Step1 {
message: "Hello, world!".to_string(),
};
let task_id = job.enqueue(input).await?;
let task_id = job.enqueue(&input).await?;

// Dequeue the first task.
let Some(dequeued_task) = queue.dequeue(&pool).await? else {
Expand Down Expand Up @@ -1846,7 +1839,7 @@ mod tests {
let input = Step1 {
message: "Hello, world!".to_string(),
};
let task_id = job.enqueue(input.clone()).await?;
let task_id = job.enqueue(&input).await?;

// Dequeue the first task.
let Some(dequeued_task) = queue.dequeue(&pool).await? else {
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
//! .await?;
//!
//! // Here we enqueue a new job to be processed later.
//! job.enqueue(WelcomeEmail {
//! job.enqueue(&WelcomeEmail {
//! user_id: 42,
//! email: "[email protected]".to_string(),
//! name: "Ferris".to_string(),
Expand Down Expand Up @@ -161,7 +161,7 @@
//! .await?;
//!
//! // Enqueue the job for the given order.
//! job.enqueue(GenerateReceipt { order_id: 42 }).await?;
//! job.enqueue(&GenerateReceipt { order_id: 42 }).await?;
//!
//! // Start processing enqueued jobs.
//! job.start().await??;
Expand Down
32 changes: 16 additions & 16 deletions src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
//! .pool(pool.clone())
//! .build()
//! .await?;
//! # queue.enqueue(&pool, &MyTask, ()).await?;
//! # queue.enqueue(&pool, &MyTask, &()).await?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! # });
//! # }
Expand Down Expand Up @@ -104,7 +104,7 @@
//! # let task = MyTask;
//!
//! // Enqueue the task.
//! queue.enqueue(&pool, &task, ()).await?;
//! queue.enqueue(&pool, &task, &()).await?;
//!
//! if let Some(task) = queue.dequeue(&pool).await? {
//! // Process the task here
Expand Down Expand Up @@ -326,7 +326,7 @@ impl<T: Task> Queue<T> {
/// `std::time::Duration`.
///
/// [ULID]: https://github.com/ulid/spec?tab=readme-ov-file#specification
pub async fn enqueue<'a, E>(&self, executor: E, task: &T, input: T::Input) -> Result<TaskId>
pub async fn enqueue<'a, E>(&self, executor: E, task: &T, input: &T::Input) -> Result<TaskId>
where
E: PgExecutor<'a>,
{
Expand All @@ -345,7 +345,7 @@ impl<T: Task> Queue<T> {
&self,
executor: E,
task: &T,
input: T::Input,
input: &T::Input,
delay: Span,
) -> Result<TaskId>
where
Expand All @@ -368,7 +368,7 @@ impl<T: Task> Queue<T> {
&self,
executor: E,
task: &T,
input: T::Input,
input: &T::Input,
delay: Span,
) -> Result<TaskId>
where
Expand Down Expand Up @@ -1061,7 +1061,7 @@ mod tests {
let input = serde_json::json!({ "key": "value" });
let task = TestTask;

let task_id = queue.enqueue(&pool, &task, input.clone()).await?;
let task_id = queue.enqueue(&pool, &task, &input).await?;

// Query the database to verify the task was enqueued
let dequeued_task = sqlx::query!(
Expand Down Expand Up @@ -1115,7 +1115,7 @@ mod tests {
let task = TestTask;

let task_id = queue
.enqueue_after(&pool, &task, input.clone(), 5.minutes())
.enqueue_after(&pool, &task, &input, 5.minutes())
.await?;

// Check the delay
Expand Down Expand Up @@ -1149,7 +1149,7 @@ mod tests {
let task = TestTask;

// Enqueue a task
let task_id = queue.enqueue(&pool, &task, input.clone()).await?;
let task_id = queue.enqueue(&pool, &task, &input).await?;

// Dequeue the task
let dequeued_task = queue.dequeue(&pool).await?;
Expand Down Expand Up @@ -1210,7 +1210,7 @@ mod tests {

// Enqueue multiple tasks
for _ in 0..5 {
queue.enqueue(&pool, &task, input.clone()).await?;
queue.enqueue(&pool, &task, &input).await?;
}

// Simulate concurrent dequeues
Expand Down Expand Up @@ -1267,7 +1267,7 @@ mod tests {
let task = TestTask;

// Enqueue a task
let task_id = queue.enqueue(&pool, &task, input).await?;
let task_id = queue.enqueue(&pool, &task, &input).await?;

// Mark the task as in progress
queue.mark_task_in_progress(&pool, task_id).await?;
Expand Down Expand Up @@ -1299,7 +1299,7 @@ mod tests {
let task = TestTask;

// Enqueue a task
let task_id = queue.enqueue(&pool, &task, input).await?;
let task_id = queue.enqueue(&pool, &task, &input).await?;

// Reschedule the task for retry
let retry_count = 1;
Expand Down Expand Up @@ -1348,7 +1348,7 @@ mod tests {
let task = TestTask;

// Enqueue a task
let task_id = queue.enqueue(&pool, &task, input).await?;
let task_id = queue.enqueue(&pool, &task, &input).await?;

// Cancel the task
queue.mark_task_cancelled(&pool, task_id).await?;
Expand Down Expand Up @@ -1380,7 +1380,7 @@ mod tests {
let task = TestTask;

// Enqueue a task
let task_id = queue.enqueue(&pool, &task, input).await?;
let task_id = queue.enqueue(&pool, &task, &input).await?;

// Mark the task as succeeded
queue.mark_task_succeeded(&pool, task_id).await?;
Expand Down Expand Up @@ -1466,7 +1466,7 @@ mod tests {
let task = TestTask;

// Enqueue a task
let task_id = queue.enqueue(&pool, &task, input).await?;
let task_id = queue.enqueue(&pool, &task, &input).await?;

// Mark the task as failed
queue.mark_task_failed(&pool, task_id).await?;
Expand Down Expand Up @@ -1499,7 +1499,7 @@ mod tests {
let task = TestTask;

// Enqueue a task
let task_id = queue.enqueue(&pool, &task, input).await?;
let task_id = queue.enqueue(&pool, &task, &input).await?;

// Update task failure details
let retry_count = 2;
Expand Down Expand Up @@ -1540,7 +1540,7 @@ mod tests {
let task = TestTask;

// Enqueue a task
let task_id = queue.enqueue(&pool, &task, input).await?;
let task_id = queue.enqueue(&pool, &task, &input).await?;

// Move the task to DLQ
queue.move_task_to_dlq(&pool, task_id, "test_dlq").await?;
Expand Down
Loading

0 comments on commit e48a4ea

Please sign in to comment.