15
15
// specific language governing permissions and limitations
16
16
// under the License.
17
17
18
- use std:: future:: Future ;
18
+ use std:: {
19
+ future:: Future ,
20
+ pin:: Pin ,
21
+ task:: { Context , Poll } ,
22
+ } ;
19
23
20
- use crate :: JoinSet ;
21
- use tokio:: task:: JoinError ;
24
+ use tokio:: task:: { JoinError , JoinHandle } ;
25
+
26
+ use crate :: trace_utils:: { trace_block, trace_future} ;
22
27
23
28
/// Helper that provides a simple API to spawn a single task and join it.
24
29
/// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
30
+ /// Note that if the task was spawned with `spawn_blocking`, it will only be
31
+ /// aborted if it hasn't started yet.
25
32
///
26
- /// Technically, it's just a wrapper of `JoinSet` (with size=1) .
33
+ /// Technically, it's just a wrapper of a `JoinHandle` overriding drop .
27
34
#[ derive( Debug ) ]
28
35
pub struct SpawnedTask < R > {
29
- inner : JoinSet < R > ,
36
+ inner : JoinHandle < R > ,
30
37
}
31
38
32
39
impl < R : ' static > SpawnedTask < R > {
@@ -36,8 +43,9 @@ impl<R: 'static> SpawnedTask<R> {
36
43
T : Send + ' static ,
37
44
R : Send ,
38
45
{
39
- let mut inner = JoinSet :: new ( ) ;
40
- inner. spawn ( task) ;
46
+ // Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop
47
+ #[ allow( clippy:: disallowed_methods) ]
48
+ let inner = tokio:: task:: spawn ( trace_future ( task) ) ;
41
49
Self { inner }
42
50
}
43
51
@@ -47,22 +55,21 @@ impl<R: 'static> SpawnedTask<R> {
47
55
T : Send + ' static ,
48
56
R : Send ,
49
57
{
50
- let mut inner = JoinSet :: new ( ) ;
51
- inner. spawn_blocking ( task) ;
58
+ // Ok to use spawn_blocking here as SpawnedTask handles aborting/cancelling the task on Drop
59
+ #[ allow( clippy:: disallowed_methods) ]
60
+ let inner = tokio:: task:: spawn_blocking ( trace_block ( task) ) ;
52
61
Self { inner }
53
62
}
54
63
55
64
/// Joins the task, returning the result of join (`Result<R, JoinError>`).
56
- pub async fn join ( mut self ) -> Result < R , JoinError > {
57
- self . inner
58
- . join_next ( )
59
- . await
60
- . expect ( "`SpawnedTask` instance always contains exactly 1 task" )
65
+ /// Same as awaiting the spawned task, but left for backwards compatibility.
66
+ pub async fn join ( self ) -> Result < R , JoinError > {
67
+ self . await
61
68
}
62
69
63
70
/// Joins the task and unwinds the panic if it happens.
64
71
pub async fn join_unwind ( self ) -> Result < R , JoinError > {
65
- self . join ( ) . await . map_err ( |e| {
72
+ self . await . map_err ( |e| {
66
73
// `JoinError` can be caused either by panic or cancellation. We have to handle panics:
67
74
if e. is_panic ( ) {
68
75
std:: panic:: resume_unwind ( e. into_panic ( ) ) ;
@@ -77,17 +84,32 @@ impl<R: 'static> SpawnedTask<R> {
77
84
}
78
85
}
79
86
87
+ impl < R > Future for SpawnedTask < R > {
88
+ type Output = Result < R , JoinError > ;
89
+
90
+ fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
91
+ Pin :: new ( & mut self . inner ) . poll ( cx)
92
+ }
93
+ }
94
+
95
+ impl < R > Drop for SpawnedTask < R > {
96
+ fn drop ( & mut self ) {
97
+ self . inner . abort ( ) ;
98
+ }
99
+ }
100
+
80
101
#[ cfg( test) ]
81
102
mod tests {
82
103
use super :: * ;
83
104
84
105
use std:: future:: { pending, Pending } ;
85
106
86
- use tokio:: runtime:: Runtime ;
107
+ use tokio:: { runtime:: Runtime , sync :: oneshot } ;
87
108
88
109
#[ tokio:: test]
89
110
async fn runtime_shutdown ( ) {
90
111
let rt = Runtime :: new ( ) . unwrap ( ) ;
112
+ #[ allow( clippy:: async_yields_async) ]
91
113
let task = rt
92
114
. spawn ( async {
93
115
SpawnedTask :: spawn ( async {
@@ -119,4 +141,36 @@ mod tests {
119
141
. await
120
142
. ok ( ) ;
121
143
}
144
+
145
+ #[ tokio:: test]
146
+ async fn cancel_not_started_task ( ) {
147
+ let ( sender, receiver) = oneshot:: channel :: < i32 > ( ) ;
148
+ let task = SpawnedTask :: spawn ( async {
149
+ // Shouldn't be reached.
150
+ sender. send ( 42 ) . unwrap ( ) ;
151
+ } ) ;
152
+
153
+ drop ( task) ;
154
+
155
+ // If the task was cancelled, the sender was also dropped,
156
+ // and awaiting the receiver should result in an error.
157
+ assert ! ( receiver. await . is_err( ) ) ;
158
+ }
159
+
160
+ #[ tokio:: test]
161
+ async fn cancel_ongoing_task ( ) {
162
+ let ( sender, mut receiver) = tokio:: sync:: mpsc:: channel ( 1 ) ;
163
+ let task = SpawnedTask :: spawn ( async move {
164
+ sender. send ( 1 ) . await . unwrap ( ) ;
165
+ // This line will never be reached because the channel has a buffer
166
+ // of 1.
167
+ sender. send ( 2 ) . await . unwrap ( ) ;
168
+ } ) ;
169
+ // Let the task start.
170
+ assert_eq ! ( receiver. recv( ) . await . unwrap( ) , 1 ) ;
171
+ drop ( task) ;
172
+
173
+ // The sender was dropped so we receive `None`.
174
+ assert ! ( receiver. recv( ) . await . is_none( ) ) ;
175
+ }
122
176
}
0 commit comments