use channels to process ready tasks
This introduces a trigger that notifies a channel when tasks are
inserted or updated in a "pending" state. Workers now listen on this
channel and, when a notification names their queue, begin processing
the next task.

Because channel notifications can fail, we retain the polling
mechanism, but it now defaults to a polling interval of one minute.
The interval can be configured to be shorter or longer.

Closes #19
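
For a sense of the listen side of this mechanism, here is a minimal sketch using sqlx's PgListener (illustrative only: the connection string is hypothetical, while the channel name "task_change" and the payload shape come from the migration below):

use sqlx::postgres::{PgListener, PgPool};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical connection string.
    let pool = PgPool::connect("postgres://localhost/underway").await?;

    // Subscribe to the channel the new trigger notifies.
    let mut listener = PgListener::connect_with(&pool).await?;
    listener.listen("task_change").await?;

    // Any insert or update that leaves a task 'pending' fires the trigger;
    // the payload is JSON carrying the queue name.
    let notification = listener.recv().await?;
    println!("payload: {}", notification.payload()); // {"task_queue_name": "..."}

    Ok(())
}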
maxcountryman committed Oct 24, 2024
1 parent a898a8d commit 23adffa
Showing 2 changed files with 157 additions and 80 deletions.
19 changes: 19 additions & 0 deletions migrations/20241024174106_1.sql
@@ -0,0 +1,19 @@
-- function to notify about task changes
create or replace function underway.task_change_notify()
returns trigger as $$
begin
if (new.state = 'pending') then
perform pg_notify('task_change', json_build_object(
'task_queue_name', new.task_queue_name
)::text);
end if;

return new;
end;
$$ language plpgsql;

-- trigger that calls the function after task changes
create trigger task_changed
after insert or update on underway.task
for each row
execute procedure underway.task_change_notify();
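
To see how the worker will interpret these notifications, here is a self-contained sketch of decoding the payload produced by json_build_object above (the queue name "email" is made up; the TaskChange struct mirrors the one added in worker.rs below):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct TaskChange {
    #[serde(rename = "task_queue_name")]
    queue_name: String,
}

fn main() -> serde_json::Result<()> {
    // The shape emitted by the trigger; "email" is a hypothetical queue name.
    let payload = r#"{"task_queue_name":"email"}"#;
    let change: TaskChange = serde_json::from_str(payload)?;
    assert_eq!(change.queue_name, "email");
    Ok(())
}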
218 changes: 138 additions & 80 deletions src/worker.rs
@@ -123,8 +123,9 @@ use std::sync::{
};

use jiff::{Span, ToSpan};
use serde::Deserialize;
use sqlx::{
postgres::{types::PgInterval, PgListener},
postgres::{types::PgInterval, PgListener, PgNotification},
Acquire, PgConnection,
};
use tokio::{sync::Semaphore, task::JoinSet};
@@ -136,30 +137,6 @@ use crate::{
};
pub(crate) type Result = std::result::Result<(), Error>;

/// A worker that's generic over the task it processes.
#[derive(Debug)]
pub struct Worker<T: Task> {
queue: Queue<T>,
task: Arc<T>,

// Limits the number of concurrent `Task::execute` invocations this worker will be allowed.
concurrency_limit: usize,

// Indicates the underlying queue has received a shutdown signal.
queue_shutdown: Arc<AtomicBool>,
}

impl<T: Task> Clone for Worker<T> {
fn clone(&self) -> Self {
Self {
queue: self.queue.clone(),
task: self.task.clone(),
concurrency_limit: self.concurrency_limit,
queue_shutdown: self.queue_shutdown.clone(),
}
}
}

/// Worker errors.
#[derive(Debug, thiserror::Error)]
pub enum Error {
@@ -185,6 +162,30 @@
Jiff(#[from] jiff::Error),
}

/// A worker that's generic over the task it processes.
#[derive(Debug)]
pub struct Worker<T: Task> {
queue: Queue<T>,
task: Arc<T>,

// Limits the number of concurrent `Task::execute` invocations this worker will be allowed.
concurrency_limit: usize,

// Indicates the underlying queue has received a shutdown signal.
queue_shutdown: Arc<AtomicBool>,
}

impl<T: Task> Clone for Worker<T> {
fn clone(&self) -> Self {
Self {
queue: self.queue.clone(),
task: self.task.clone(),
concurrency_limit: self.concurrency_limit,
queue_shutdown: self.queue_shutdown.clone(),
}
}
}

impl<T: Task + Sync> Worker<T> {
/// Creates a new worker with the given queue and task.
pub fn new(queue: Queue<T>, task: T) -> Self {
@@ -206,92 +207,143 @@ impl<T: Task + Sync> Worker<T> {

/// Runs the worker, processing tasks as they become available.
///
/// Tasks are processed via polling in a loop. A one-second sleep occurs
/// between polls.
/// Tasks are processed via a subscription to a Postgres channel and polling
/// in a loop. A one-minute sleep occurs between polls.
pub async fn run(&self) -> Result {
self.run_every(1.second()).await
self.run_every(1.minute()).await
}

/// Same as `run` but allows for the configuration of the span between
/// Same as `run` but allows for the configuration of the delay between
/// polls.
pub async fn run_every(&self, span: Span) -> Result {
let mut interval = tokio::time::interval(span.try_into()?);
pub async fn run_every(&self, period: Span) -> Result {
let mut interval = tokio::time::interval(period.try_into()?);

// Set up a listener for shutdown notifications
let mut shutdown_listener = PgListener::connect_with(&self.queue.pool).await?;
shutdown_listener.listen(SHUTDOWN_CHANNEL).await?;

// Set up a listener for task change notifications
let mut task_change_listener = PgListener::connect_with(&self.queue.pool).await?;
task_change_listener.listen("task_change").await?;

let concurrency_limit = Arc::new(Semaphore::new(self.concurrency_limit));
let mut processing_tasks = JoinSet::new();

loop {
tokio::select! {
notify_shutdown = shutdown_listener.recv() => {
if let Err(err) = notify_shutdown {
tracing::error!(%err, "Postgres notification error");
continue;
}
match notify_shutdown {
Ok(_) => self.handle_shutdown(&mut processing_tasks).await?,
Err(err) => {
tracing::error!(%err, "Postgres shutdown notification error");
continue
}
};

self.queue_shutdown.store(true, Ordering::SeqCst);
break;
},

let task_timeout = self.task.timeout();
notify_task_change = task_change_listener.recv() => {
match notify_task_change {
Ok(task_change) => self.handle_task_change(task_change, concurrency_limit.clone(), &mut processing_tasks).await?,
Err(err) => {
tracing::error!(%err, "Postgres task change notification error");
continue;
}
};

tracing::info!(
task.timeout = ?task_timeout,
"Waiting for all processing tasks or timeout"
);
},

// Try to join all the processing tasks before the task timeout.
let shutdown_result = tokio::time::timeout(
task_timeout.try_into()?,
async {
while let Some(res) = processing_tasks.join_next().await {
if let Err(err) = res {
tracing::error!(%err, "A processing task failed during shutdown");
}
}
}
// Polling fallback.
_ = interval.tick() => {
self.trigger_task_processing(
concurrency_limit.clone(),
&mut processing_tasks
).await;
}
}
}

match shutdown_result {
Ok(_) => {
tracing::debug!("All processing tasks completed gracefully");
},
Err(_) => {
let remaining_tasks = processing_tasks.len();
tracing::warn!(remaining_tasks, "Reached task timeout before all tasks completed");
},
}
Ok(())
}

break;
},
async fn handle_shutdown(&self, processing_tasks: &mut JoinSet<()>) -> Result {
self.queue_shutdown.store(true, Ordering::SeqCst);

_ = interval.tick() => {
if self.queue_shutdown.load(Ordering::SeqCst) {
tracing::info!("Queue is shutdown so no new tasks will be processed");
break;
}
let task_timeout = self.task.timeout();

let permit = concurrency_limit.clone().acquire_owned().await.expect("Concurrency limit semaphore should be open");
processing_tasks.spawn({
// TODO: Rather than clone the worker, we could have a separate type that
// owns task processing.
let worker = self.clone();

async move {
if let Err(err) = worker.process_next_task().await {
tracing::error!(%err, "Error processing next task");
}
drop(permit);
}
});
tracing::info!(
task.timeout = ?task_timeout,
"Waiting for all processing tasks or timeout"
);

// Wait for processing tasks to complete or timeout
let shutdown_result = tokio::time::timeout(task_timeout.try_into()?, async {
while let Some(res) = processing_tasks.join_next().await {
if let Err(err) = res {
tracing::error!(%err, "A processing task failed during shutdown");
}
}
})
.await;

match shutdown_result {
Ok(_) => {
tracing::debug!("All processing tasks completed gracefully");
}
Err(_) => {
let remaining_tasks = processing_tasks.len();
tracing::warn!(
remaining_tasks,
"Reached task timeout before all tasks completed"
);
}
}

Ok(())
}

async fn handle_task_change(
&self,
task_change: PgNotification,
concurrency_limit: Arc<Semaphore>,
processing_tasks: &mut JoinSet<()>,
) -> Result {
let payload = task_change.payload();
let decoded: TaskChange = serde_json::from_str(payload).map_err(|err| {
tracing::error!(%err, "Invalid task change payload; ignoring");
err
})?;

if decoded.queue_name == self.queue.name {
self.trigger_task_processing(concurrency_limit, processing_tasks)
.await;
}

Ok(())
}

async fn trigger_task_processing(
&self,
concurrency_limit: Arc<Semaphore>,
processing_tasks: &mut JoinSet<()>,
) {
let Ok(permit) = concurrency_limit.clone().try_acquire_owned() else {
tracing::debug!("Concurrency limit reached");
return;
};

processing_tasks.spawn({
let worker = self.clone();
async move {
if let Err(err) = worker.process_next_task().await {
tracing::error!(%err, "Error processing next task");
}
drop(permit);
}
});
}

/// Processes the next available task in the queue.
///
/// When a task is found, its execute method will be invoked with the
@@ -448,6 +500,12 @@ impl<T: Task + Sync> Worker<T> {
}
}

#[derive(Debug, Deserialize)]
struct TaskChange {
#[serde(rename = "task_queue_name")]
queue_name: String,
}

pub(crate) fn pg_interval_to_span(
PgInterval {
months,
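Taken together, notifications now drive task processing and the polling interval is only a fallback. A usage sketch (the import paths and the queue/task values are assumptions; Worker::new, run, and run_every are as defined in this diff):

use jiff::ToSpan;
use underway::{Queue, Task, Worker}; // hypothetical paths; adjust to the crate's actual exports

async fn start<T: Task + Sync>(queue: Queue<T>, task: T) -> Result<(), underway::worker::Error> {
    let worker = Worker::new(queue, task);

    // worker.run().await listens for task_change notifications and polls
    // once a minute as a fallback; run_every shortens that fallback.
    worker.run_every(10.seconds()).await
}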
