diff --git a/Cargo.toml b/Cargo.toml
index 80f3ee5..c9b1a15 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ tokio = { version = "1.40.0", features = [
 tracing = { version = "0.1.40", features = ["log"] }
 ulid = { version = "1.1.3", features = ["uuid"] }
 uuid = { version = "1.10.0", features = ["v4"] }
+num_cpus = "1.16.0"
 
 [dev-dependencies]
 futures = "0.3.30"
diff --git a/src/worker.rs b/src/worker.rs
index e5e6d76..c434272 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -126,7 +126,7 @@ use sqlx::{
     postgres::{types::PgInterval, PgListener},
     PgConnection,
 };
-use tokio::sync::{Notify, Semaphore};
+use tokio::{sync::Semaphore, task::JoinSet};
 use tracing::instrument;
 
 use crate::{
@@ -142,14 +142,11 @@ pub struct Worker<T: Task> {
     queue: Queue<T>,
     task: Arc<T>,
 
+    // Limits the number of concurrent `Task::execute` invocations this worker is allowed to run.
+    concurrency_limit: usize,
+
     // Indicates the underlying queue has received a shutdown signal.
     queue_shutdown: Arc<AtomicBool>,
-
-    // Indicates that this worker is processing a task.
-    processing: Arc<AtomicBool>,
-
-    // Notifies when a task is done processing.
-    processing_done: Arc<Notify>,
 }
 
 impl<T: Task> Clone for Worker<T> {
@@ -157,11 +154,8 @@ impl<T: Task> Clone for Worker<T> {
         Self {
             queue: self.queue.clone(),
             task: Arc::clone(&self.task),
+            concurrency_limit: self.concurrency_limit,
             queue_shutdown: self.queue_shutdown.clone(),
-
-            // Clones don't share processing state.
-            processing: Arc::new(false.into()),
-            processing_done: Arc::new(Notify::new()),
         }
     }
 }
@@ -199,9 +193,8 @@ where
         Self {
             queue: job.queue.clone(),
             task: Arc::new(job),
+            concurrency_limit: num_cpus::get(),
             queue_shutdown: Arc::new(false.into()),
-            processing: Arc::new(false.into()),
-            processing_done: Arc::new(Notify::new()),
         }
     }
 }
@@ -214,9 +207,8 @@ where
         Self {
             queue: job.queue.clone(),
             task: Arc::new(job.to_owned()),
+            concurrency_limit: num_cpus::get(),
             queue_shutdown: Arc::new(false.into()),
-            processing: Arc::new(false.into()),
-            processing_done: Arc::new(Notify::new()),
         }
     }
 }
@@ -227,12 +219,17 @@ impl<T: Task> Worker<T> {
         Self {
             queue,
             task: Arc::new(task),
+            concurrency_limit: num_cpus::get(),
             queue_shutdown: Arc::new(false.into()),
-            processing: Arc::new(false.into()),
-            processing_done: Arc::new(Notify::new()),
         }
     }
 
+    /// Sets the concurrency limit for this worker.
+    pub fn concurrency_limit(mut self, concurrency_limit: usize) -> Self {
+        self.concurrency_limit = concurrency_limit;
+        self
+    }
+
     /// Runs the worker, processing tasks as they become available.
     ///
     /// Tasks are processed via polling in a loop. A one-second sleep occurs
@@ -250,8 +247,8 @@
         let mut shutdown_listener = PgListener::connect_with(&self.queue.pool).await?;
         shutdown_listener.listen(SHUTDOWN_CHANNEL).await?;
 
-        // TODO: Concurrent limit config.
-        let permits = Arc::new(Semaphore::new(1));
+        let concurrency_limit = Arc::new(Semaphore::new(self.concurrency_limit));
+        let mut processing_tasks = JoinSet::new();
 
         loop {
             tokio::select! {
@@ -263,25 +260,35 @@
                 _ = shutdown_listener.recv() => {
                     self.queue_shutdown.store(true, Ordering::SeqCst);
 
-                    if self.processing.load(Ordering::SeqCst) {
-                        tracing::info!(
-                            task.timeout = ?self.task.timeout(),
-                            "Worker is in-progress with a task, waiting for done notification or task timeout"
-                        );
-
-                        // Wait for either task processing to be done or the task timeout,
-                        // whichever is first.
-                        tokio::select! {
-                            _ = self.processing_done.notified() => {
-                                tracing::debug!("Received done notification");
-                            },
-
-                            _ = tokio::time::sleep(self.task.timeout().try_into().unwrap()) => {
-                                tracing::debug!("Slept until task timeout");
-                            },
-                        }
-                    }
+                    let task_timeout = self.task.timeout();
+
+                    tracing::info!(
+                        task.timeout = ?task_timeout,
+                        "Waiting for all processing tasks or timeout"
+                    );
+
+                    // Try to join all the processing tasks before the task timeout.
+                    let shutdown_result = tokio::time::timeout(
+                        task_timeout.try_into()?,
+                        async {
+                            while let Some(res) = processing_tasks.join_next().await {
+                                if let Err(err) = res {
+                                    tracing::error!(%err, "A processing task failed during shutdown");
+                                }
+                            }
+                        }
+                    ).await;
+
+                    match shutdown_result {
+                        Ok(_) => {
+                            tracing::debug!("All processing tasks completed gracefully");
+                        },
+
+                        Err(_) => {
+                            tracing::debug!("Reached task timeout before all tasks completed");
+                        },
+                    }
 
                     break;
                 },
 
@@ -290,16 +297,17 @@
                 _ = interval.tick() => {
                     if self.queue_shutdown.load(Ordering::SeqCst) {
                         break;
                     }
 
-                    let permits = Arc::clone(&permits);
-
-                    // Cancellation safety: spawned tasks are managed by Tokio's runtime and will not
-                    // be dropped when the interval.tick future is cancelled.
-                    tokio::spawn({
+                    let permit = concurrency_limit.clone().acquire_owned().await.expect("Concurrency limit semaphore should be open");
+                    processing_tasks.spawn({
+                        // TODO: Rather than clone the worker, we could have a separate type that
+                        // owns task processing.
                         let worker = self.clone();
+
                         async move {
-                            let _permit = permits.acquire().await.expect("Semaphore should be open");
-                            worker.process_next_task().await?;
-                            Ok::<(), Error>(())
+                            if let Err(err) = worker.process_next_task().await {
+                                tracing::error!(%err, "Error processing next task");
+                            }
+                            drop(permit);
                         }
                     });
                 }
@@ -327,8 +335,6 @@
         let mut tx = self.queue.pool.begin().await?;
 
         if let Some(task_row) = self.queue.dequeue(&mut tx).await? {
-            self.processing.store(true, Ordering::SeqCst);
-
             let task_id = task_row.id;
             tracing::Span::current().record("task.id", task_id.as_hyphenated().to_string());
 
@@ -365,9 +371,6 @@
             }
 
             tx.commit().await?;
-
-            self.processing.store(false, Ordering::SeqCst);
-            self.processing_done.notify_waiters();
         }
 
         Ok(())
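For reference, a minimal usage sketch of the new builder-style setter. The `queue` and `task` values and the `run` call are assumed context from the surrounding crate, not part of this diff:

    // Hypothetical setup: `queue` and `task` come from the caller.
    // The default limit is `num_cpus::get()`; override it explicitly here.
    let worker = Worker::new(queue, task).concurrency_limit(8);
    worker.run().await?;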
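The run loop now rests on three mechanics: an owned semaphore permit bounds concurrent `process_next_task` invocations, a `JoinSet` tracks every spawned task, and shutdown drains that set under `tokio::time::timeout`. Below is a self-contained sketch of the same pattern, using illustrative placeholder work rather than this crate's types:

    use std::{sync::Arc, time::Duration};
    use tokio::{sync::Semaphore, task::JoinSet};

    #[tokio::main]
    async fn main() {
        let concurrency_limit = Arc::new(Semaphore::new(2));
        let mut processing_tasks = JoinSet::new();

        for i in 0..8u64 {
            // Backpressure: wait for a free permit before spawning more work.
            let permit = concurrency_limit
                .clone()
                .acquire_owned()
                .await
                .expect("semaphore should be open");

            processing_tasks.spawn(async move {
                // Placeholder for `worker.process_next_task()`.
                tokio::time::sleep(Duration::from_millis(50 * i)).await;
                drop(permit); // Release the slot once the task finishes.
            });
        }

        // Shutdown: drain the set, giving up once the timeout elapses. Tasks
        // still running when the JoinSet is dropped are aborted.
        let drained = tokio::time::timeout(Duration::from_secs(1), async {
            while let Some(res) = processing_tasks.join_next().await {
                if let Err(err) = res {
                    eprintln!("task failed: {err}");
                }
            }
        })
        .await;

        match drained {
            Ok(()) => println!("all tasks completed gracefully"),
            Err(_) => println!("timeout reached before all tasks completed"),
        }
    }

Acquiring the permit inside the polling loop, before `spawn`, is what keeps the loop itself from outrunning the limit; dropping the `JoinSet` after the timeout is what ultimately cancels stragglers.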