1
+ use std:: rc:: Rc ;
1
2
use std:: result:: Result as StdResult ;
2
3
use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
3
4
4
5
use mlua:: {
5
6
AnyUserData , ExternalError , ExternalResult , Function , Lua , Result , Table , UserData , Value ,
6
7
} ;
7
8
use tokio:: sync:: mpsc:: { self , UnboundedReceiver , UnboundedSender } ;
8
- use tokio:: sync:: oneshot;
9
+ use tokio:: sync:: { oneshot, watch } ;
9
10
use tokio:: task:: JoinHandle ;
10
11
use tokio:: time:: { Duration , Instant } ;
11
12
use tracing:: warn;
@@ -30,8 +31,12 @@ struct TaskHandle {
30
31
join_handle_rx : Option < oneshot:: Receiver < TaskJoinHandle > > ,
31
32
}
32
33
34
+ #[ derive( Clone , Copy ) ]
33
35
struct MaxBackgroundTasks ( Option < u64 > ) ;
34
36
37
+ #[ derive( Clone ) ]
38
+ struct ShutdownNotifier ( watch:: Sender < bool > ) ;
39
+
35
40
// Global task identifier
36
41
static NEXT_TASK_ID : AtomicU64 = AtomicU64 :: new ( 1 ) ;
37
42
@@ -71,7 +76,7 @@ impl UserData for TaskHandle {
71
76
}
72
77
73
78
fn spawn_task ( lua : & Lua , arg : Value ) -> Result < StdResult < TaskHandle , String > > {
74
- let max_background_tasks = lua. app_data_ref :: < MaxBackgroundTasks > ( ) . unwrap ( ) ;
79
+ let max_background_tasks = * lua. app_data_ref :: < MaxBackgroundTasks > ( ) . unwrap ( ) ;
75
80
let current_tasks = tasks_counter_get ! ( ) ;
76
81
77
82
if let Some ( max_tasks) = max_background_tasks. 0 {
@@ -128,27 +133,38 @@ fn spawn_task(lua: &Lua, arg: Value) -> Result<StdResult<TaskHandle, String>> {
128
133
}
129
134
130
135
pub fn start_task_scheduler ( lua : & Lua , max_background_tasks : Option < u64 > ) {
131
- let lua = lua. clone ( ) ;
136
+ let lua = Rc :: new ( lua. clone ( ) ) ;
132
137
let mut task_rx = lua
133
138
. remove_app_data :: < UnboundedReceiver < Task > > ( )
134
139
. expect ( "Failed to get task receiver" ) ;
135
140
136
141
lua. set_app_data ( MaxBackgroundTasks ( max_background_tasks) ) ;
137
142
143
+ let ( shutdown_tx, shutdown_rx) = watch:: channel ( false ) ;
144
+ lua. set_app_data ( ShutdownNotifier ( shutdown_tx) ) ;
145
+
138
146
tokio:: task:: spawn_local ( async move {
139
147
while let Some ( task) = task_rx. recv ( ) . await {
148
+ let lua = lua. clone ( ) ;
149
+ let mut shutdown = shutdown_rx. clone ( ) ;
140
150
let join_handle = tokio:: task:: spawn_local ( async move {
141
151
let start = Instant :: now ( ) ;
142
152
let _task_count_guard = tasks_counter_inc ! ( ) ;
153
+ // Keep Lua instance alive while task is running
154
+ let _lua_guard = lua;
143
155
let task_future = task. handler . call_async :: < Value > ( ( ) ) ;
144
156
145
157
let result = match task. timeout {
146
- Some ( timeout) => ntex:: time:: timeout ( timeout, task_future) . await ,
147
- None => Ok ( task_future. await ) ,
158
+ Some ( timeout) => tokio:: select! {
159
+ _ = shutdown. wait_for( |& x| x) => return Err ( "task scheduler shutdown" . into_lua_err( ) ) ,
160
+ result = ntex:: time:: timeout( timeout, task_future) =>
161
+ result. unwrap_or_else( |_| Err ( "task exceeded timeout" . into_lua_err( ) ) ) ,
162
+ } ,
163
+ None => tokio:: select! {
164
+ _ = shutdown. wait_for( |& x| x) => return Err ( "task scheduler shutdown" . into_lua_err( ) ) ,
165
+ result = task_future => result,
166
+ } ,
148
167
} ;
149
- // Outer Result errors will always be timeouts
150
- let result = result
151
- . unwrap_or_else ( |_| Err ( "task exceeded timeout" . to_string ( ) ) . into_lua_err ( ) ) ;
152
168
153
169
// Record task metrics
154
170
match task. name {
@@ -178,7 +194,9 @@ pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {
178
194
179
195
pub fn stop_task_scheduler ( lua : & Lua ) {
180
196
lua. remove_app_data :: < UnboundedSender < Task > > ( ) ;
181
- lua. remove_app_data :: < UnboundedReceiver < Task > > ( ) ;
197
+
198
+ // Notify all tasks to stop
199
+ _ = lua. app_data_ref :: < ShutdownNotifier > ( ) . unwrap ( ) . 0 . send ( true ) ;
182
200
}
183
201
184
202
pub fn create_module ( lua : & Lua ) -> Result < Table > {
@@ -192,14 +210,13 @@ pub fn create_module(lua: &Lua) -> Result<Table> {
192
210
193
211
#[ cfg( test) ]
194
212
mod tests {
195
- use std:: rc:: Rc ;
196
213
use std:: time:: Duration ;
197
214
198
215
use mlua:: { chunk, Lua , Result } ;
199
216
200
217
#[ ntex:: test]
201
218
async fn test_tasks ( ) -> Result < ( ) > {
202
- let lua = Rc :: new ( Lua :: new ( ) ) ;
219
+ let lua = Lua :: new ( ) ;
203
220
204
221
lua. globals ( ) . set ( "tasks" , super :: create_module ( & lua) ?) ?;
205
222
lua. globals ( ) . set (
@@ -331,6 +348,8 @@ mod tests {
331
348
. await
332
349
. unwrap ( ) ;
333
350
351
+ super :: stop_task_scheduler ( & lua) ;
352
+
334
353
Ok ( ( ) )
335
354
}
336
355
}
0 commit comments