Skip to content

Commit b47dbba

Browse files
committed
Implement Future for SpawnedTask.
It allows polling a SpawnedTask, instead of just joining it. The implementation is changed from `JoinSet` to a `JoinHandle` to simplify the code, as `JoinSet` doesn't provide any additional benefits.
1 parent 51cc046 commit b47dbba

File tree

1 file changed

+70
-16
lines changed

1 file changed

+70
-16
lines changed

datafusion/common-runtime/src/common.rs

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,25 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::future::Future;
18+
use std::{
19+
future::Future,
20+
pin::Pin,
21+
task::{Context, Poll},
22+
};
1923

20-
use crate::JoinSet;
21-
use tokio::task::JoinError;
24+
use tokio::task::{JoinError, JoinHandle};
25+
26+
use crate::trace_utils::{trace_block, trace_future};
2227

2328
/// Helper that provides a simple API to spawn a single task and join it.
2429
/// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
30+
/// Note that if the task was spawned with `spawn_blocking`, it will only be
31+
/// aborted if it hasn't started yet.
2532
///
26-
/// Technically, it's just a wrapper of `JoinSet` (with size=1).
33+
/// Technically, it's just a wrapper of a `JoinHandle` overriding drop.
2734
#[derive(Debug)]
2835
pub struct SpawnedTask<R> {
29-
inner: JoinSet<R>,
36+
inner: JoinHandle<R>,
3037
}
3138

3239
impl<R: 'static> SpawnedTask<R> {
@@ -36,8 +43,9 @@ impl<R: 'static> SpawnedTask<R> {
3643
T: Send + 'static,
3744
R: Send,
3845
{
39-
let mut inner = JoinSet::new();
40-
inner.spawn(task);
46+
// Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop
47+
#[allow(clippy::disallowed_methods)]
48+
let inner = tokio::task::spawn(trace_future(task));
4149
Self { inner }
4250
}
4351

@@ -47,22 +55,21 @@ impl<R: 'static> SpawnedTask<R> {
4755
T: Send + 'static,
4856
R: Send,
4957
{
50-
let mut inner = JoinSet::new();
51-
inner.spawn_blocking(task);
58+
// Ok to use spawn_blocking here as SpawnedTask handles aborting/cancelling the task on Drop
59+
#[allow(clippy::disallowed_methods)]
60+
let inner = tokio::task::spawn_blocking(trace_block(task));
5261
Self { inner }
5362
}
5463

5564
/// Joins the task, returning the result of join (`Result<R, JoinError>`).
56-
pub async fn join(mut self) -> Result<R, JoinError> {
57-
self.inner
58-
.join_next()
59-
.await
60-
.expect("`SpawnedTask` instance always contains exactly 1 task")
65+
/// Same as awaiting the spawned task, but left for backwards compatibility.
66+
pub async fn join(self) -> Result<R, JoinError> {
67+
self.await
6168
}
6269

6370
/// Joins the task and unwinds the panic if it happens.
6471
pub async fn join_unwind(self) -> Result<R, JoinError> {
65-
self.join().await.map_err(|e| {
72+
self.await.map_err(|e| {
6673
// `JoinError` can be caused either by panic or cancellation. We have to handle panics:
6774
if e.is_panic() {
6875
std::panic::resume_unwind(e.into_panic());
@@ -77,17 +84,32 @@ impl<R: 'static> SpawnedTask<R> {
7784
}
7885
}
7986

87+
impl<R> Future for SpawnedTask<R> {
88+
type Output = Result<R, JoinError>;
89+
90+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
91+
Pin::new(&mut self.inner).poll(cx)
92+
}
93+
}
94+
95+
impl<R> Drop for SpawnedTask<R> {
96+
fn drop(&mut self) {
97+
self.inner.abort();
98+
}
99+
}
100+
80101
#[cfg(test)]
81102
mod tests {
82103
use super::*;
83104

84105
use std::future::{pending, Pending};
85106

86-
use tokio::runtime::Runtime;
107+
use tokio::{runtime::Runtime, sync::oneshot};
87108

88109
#[tokio::test]
89110
async fn runtime_shutdown() {
90111
let rt = Runtime::new().unwrap();
112+
#[allow(clippy::async_yields_async)]
91113
let task = rt
92114
.spawn(async {
93115
SpawnedTask::spawn(async {
@@ -119,4 +141,36 @@ mod tests {
119141
.await
120142
.ok();
121143
}
144+
145+
#[tokio::test]
146+
async fn cancel_not_started_task() {
147+
let (sender, receiver) = oneshot::channel::<i32>();
148+
let task = SpawnedTask::spawn(async {
149+
// Shouldn't be reached.
150+
sender.send(42).unwrap();
151+
});
152+
153+
drop(task);
154+
155+
// If the task was cancelled, the sender was also dropped,
156+
// and awaiting the receiver should result in an error.
157+
assert!(receiver.await.is_err());
158+
}
159+
160+
#[tokio::test]
161+
async fn cancel_ongoing_task() {
162+
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
163+
let task = SpawnedTask::spawn(async move {
164+
sender.send(1).await.unwrap();
165+
// This line will never be reached because the channel has a buffer
166+
// of 1.
167+
sender.send(2).await.unwrap();
168+
});
169+
// Let the task start.
170+
assert_eq!(receiver.recv().await.unwrap(), 1);
171+
drop(task);
172+
173+
// The sender was dropped so we receive `None`.
174+
assert!(receiver.recv().await.is_none());
175+
}
122176
}

0 commit comments

Comments
 (0)