Skip to content

Commit dfd5e10

Browse files
committed
Add shutdown notifier to task scheduler
1 parent d09afbc commit dfd5e10

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

casper-server/src/context.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ impl AppContextInner {
141141

142142
// Start task scheduler
143143
let max_background_tasks = self.config.main.max_background_tasks;
144-
lua::tasks::start_task_scheduler(&lua, max_background_tasks);
144+
lua::tasks::start_task_scheduler(lua, max_background_tasks);
145145

146146
// Enable sandboxing before loading user code
147147
lua.sandbox(true)?;

casper-server/src/lua/tasks.rs

+30-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
use std::rc::Rc;
12
use std::result::Result as StdResult;
23
use std::sync::atomic::{AtomicU64, Ordering};
34

45
use mlua::{
56
AnyUserData, ExternalError, ExternalResult, Function, Lua, Result, Table, UserData, Value,
67
};
78
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
8-
use tokio::sync::oneshot;
9+
use tokio::sync::{oneshot, watch};
910
use tokio::task::JoinHandle;
1011
use tokio::time::{Duration, Instant};
1112
use tracing::warn;
@@ -30,8 +31,12 @@ struct TaskHandle {
3031
join_handle_rx: Option<oneshot::Receiver<TaskJoinHandle>>,
3132
}
3233

34+
#[derive(Clone, Copy)]
3335
struct MaxBackgroundTasks(Option<u64>);
3436

37+
#[derive(Clone)]
38+
struct ShutdownNotifier(watch::Sender<bool>);
39+
3540
// Global task identifier
3641
static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);
3742

@@ -71,7 +76,7 @@ impl UserData for TaskHandle {
7176
}
7277

7378
fn spawn_task(lua: &Lua, arg: Value) -> Result<StdResult<TaskHandle, String>> {
74-
let max_background_tasks = lua.app_data_ref::<MaxBackgroundTasks>().unwrap();
79+
let max_background_tasks = *lua.app_data_ref::<MaxBackgroundTasks>().unwrap();
7580
let current_tasks = tasks_counter_get!();
7681

7782
if let Some(max_tasks) = max_background_tasks.0 {
@@ -128,27 +133,38 @@ fn spawn_task(lua: &Lua, arg: Value) -> Result<StdResult<TaskHandle, String>> {
128133
}
129134

130135
pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {
131-
let lua = lua.clone();
136+
let lua = Rc::new(lua.clone());
132137
let mut task_rx = lua
133138
.remove_app_data::<UnboundedReceiver<Task>>()
134139
.expect("Failed to get task receiver");
135140

136141
lua.set_app_data(MaxBackgroundTasks(max_background_tasks));
137142

143+
let (shutdown_tx, shutdown_rx) = watch::channel(false);
144+
lua.set_app_data(ShutdownNotifier(shutdown_tx));
145+
138146
tokio::task::spawn_local(async move {
139147
while let Some(task) = task_rx.recv().await {
148+
let lua = lua.clone();
149+
let mut shutdown = shutdown_rx.clone();
140150
let join_handle = tokio::task::spawn_local(async move {
141151
let start = Instant::now();
142152
let _task_count_guard = tasks_counter_inc!();
153+
// Keep Lua instance alive while task is running
154+
let _lua_guard = lua;
143155
let task_future = task.handler.call_async::<Value>(());
144156

145157
let result = match task.timeout {
146-
Some(timeout) => ntex::time::timeout(timeout, task_future).await,
147-
None => Ok(task_future.await),
158+
Some(timeout) => tokio::select! {
159+
_ = shutdown.wait_for(|&x| x) => return Err("task scheduler shutdown".into_lua_err()),
160+
result = ntex::time::timeout(timeout, task_future) =>
161+
result.unwrap_or_else(|_| Err("task exceeded timeout".into_lua_err())),
162+
},
163+
None => tokio::select! {
164+
_ = shutdown.wait_for(|&x| x) => return Err("task scheduler shutdown".into_lua_err()),
165+
result = task_future => result,
166+
},
148167
};
149-
// Outer Result errors will always be timeouts
150-
let result = result
151-
.unwrap_or_else(|_| Err("task exceeded timeout".to_string()).into_lua_err());
152168

153169
// Record task metrics
154170
match task.name {
@@ -178,7 +194,9 @@ pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {
178194

179195
pub fn stop_task_scheduler(lua: &Lua) {
180196
lua.remove_app_data::<UnboundedSender<Task>>();
181-
lua.remove_app_data::<UnboundedReceiver<Task>>();
197+
198+
// Notify all tasks to stop
199+
_ = lua.app_data_ref::<ShutdownNotifier>().unwrap().0.send(true);
182200
}
183201

184202
pub fn create_module(lua: &Lua) -> Result<Table> {
@@ -192,14 +210,13 @@ pub fn create_module(lua: &Lua) -> Result<Table> {
192210

193211
#[cfg(test)]
194212
mod tests {
195-
use std::rc::Rc;
196213
use std::time::Duration;
197214

198215
use mlua::{chunk, Lua, Result};
199216

200217
#[ntex::test]
201218
async fn test_tasks() -> Result<()> {
202-
let lua = Rc::new(Lua::new());
219+
let lua = Lua::new();
203220

204221
lua.globals().set("tasks", super::create_module(&lua)?)?;
205222
lua.globals().set(
@@ -331,6 +348,8 @@ mod tests {
331348
.await
332349
.unwrap();
333350

351+
super::stop_task_scheduler(&lua);
352+
334353
Ok(())
335354
}
336355
}

0 commit comments

Comments
 (0)