Skip to content

Commit 02a470f

Browse files
authored
Replace AbortOnDropSingle / AbortOnDropMany with tokio JoinSet (#6750)
* Use JoinSet in MemTable
* Fix error handling
* Refactor AbortOnDropSingle in csv physical plan
* Fix csv write physical plan error propagation
* Refactor json write physical plan to use JoinSet
* Refactor parquet write physical plan to use JoinSet
* Refactor collect_partitioned to use JoinSet
* Refactor pull_from_input method to make it easier to read
* Fix typo
1 parent 07a721f commit 02a470f

File tree

6 files changed

+127
-93
lines changed

6 files changed

+127
-93
lines changed

datafusion/core/src/datasource/memory.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ use async_trait::async_trait;
2929
use datafusion_common::SchemaExt;
3030
use datafusion_execution::TaskContext;
3131
use tokio::sync::RwLock;
32+
use tokio::task::JoinSet;
3233

3334
use crate::datasource::{TableProvider, TableType};
3435
use crate::error::{DataFusionError, Result};
3536
use crate::execution::context::SessionState;
3637
use crate::logical_expr::Expr;
37-
use crate::physical_plan::common::AbortOnDropSingle;
3838
use crate::physical_plan::insert::{DataSink, InsertExec};
3939
use crate::physical_plan::memory::MemoryExec;
4040
use crate::physical_plan::{common, SendableRecordBatchStream};
@@ -89,26 +89,31 @@ impl MemTable {
8989
let exec = t.scan(state, None, &[], None).await?;
9090
let partition_count = exec.output_partitioning().partition_count();
9191

92-
let tasks = (0..partition_count)
93-
.map(|part_i| {
94-
let task = state.task_ctx();
95-
let exec = exec.clone();
96-
let task = tokio::spawn(async move {
97-
let stream = exec.execute(part_i, task)?;
98-
common::collect(stream).await
99-
});
100-
101-
AbortOnDropSingle::new(task)
102-
})
103-
// this collect *is needed* so that the join below can
104-
// switch between tasks
105-
.collect::<Vec<_>>();
92+
let mut join_set = JoinSet::new();
93+
94+
for part_idx in 0..partition_count {
95+
let task = state.task_ctx();
96+
let exec = exec.clone();
97+
join_set.spawn(async move {
98+
let stream = exec.execute(part_idx, task)?;
99+
common::collect(stream).await
100+
});
101+
}
106102

107103
let mut data: Vec<Vec<RecordBatch>> =
108104
Vec::with_capacity(exec.output_partitioning().partition_count());
109105

110-
for result in futures::future::join_all(tasks).await {
111-
data.push(result.map_err(|e| DataFusionError::External(Box::new(e)))??)
106+
while let Some(result) = join_set.join_next().await {
107+
match result {
108+
Ok(res) => data.push(res?),
109+
Err(e) => {
110+
if e.is_panic() {
111+
std::panic::resume_unwind(e.into_panic());
112+
} else {
113+
unreachable!();
114+
}
115+
}
116+
}
112117
}
113118

114119
let exec = MemoryExec::try_new(&data, schema.clone(), None)?;

datafusion/core/src/datasource/physical_plan/csv.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ use crate::datasource::physical_plan::file_stream::{
2323
};
2424
use crate::datasource::physical_plan::FileMeta;
2525
use crate::error::{DataFusionError, Result};
26-
use crate::physical_plan::common::AbortOnDropSingle;
2726
use crate::physical_plan::expressions::PhysicalSortExpr;
2827
use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
2928
use crate::physical_plan::{
@@ -46,7 +45,7 @@ use std::fs;
4645
use std::path::Path;
4746
use std::sync::Arc;
4847
use std::task::Poll;
49-
use tokio::task::{self, JoinHandle};
48+
use tokio::task::JoinSet;
5049

5150
/// Execution plan for scanning a CSV file
5251
#[derive(Debug, Clone)]
@@ -331,7 +330,7 @@ pub async fn plan_to_csv(
331330
)));
332331
}
333332

334-
let mut tasks = vec![];
333+
let mut join_set = JoinSet::new();
335334
for i in 0..plan.output_partitioning().partition_count() {
336335
let plan = plan.clone();
337336
let filename = format!("part-{i}.csv");
@@ -340,22 +339,29 @@ pub async fn plan_to_csv(
340339
let mut writer = csv::Writer::new(file);
341340
let stream = plan.execute(i, task_ctx.clone())?;
342341

343-
let handle: JoinHandle<Result<()>> = task::spawn(async move {
344-
stream
342+
join_set.spawn(async move {
343+
let result: Result<()> = stream
345344
.map(|batch| writer.write(&batch?))
346345
.try_collect()
347346
.await
348-
.map_err(DataFusionError::from)
347+
.map_err(DataFusionError::from);
348+
result
349349
});
350-
tasks.push(AbortOnDropSingle::new(handle));
351350
}
352351

353-
futures::future::join_all(tasks)
354-
.await
355-
.into_iter()
356-
.try_for_each(|result| {
357-
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
358-
})?;
352+
while let Some(result) = join_set.join_next().await {
353+
match result {
354+
Ok(res) => res?, // propagate DataFusion error
355+
Err(e) => {
356+
if e.is_panic() {
357+
std::panic::resume_unwind(e.into_panic());
358+
} else {
359+
unreachable!();
360+
}
361+
}
362+
}
363+
}
364+
359365
Ok(())
360366
}
361367

datafusion/core/src/datasource/physical_plan/json.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ use crate::datasource::physical_plan::file_stream::{
2222
};
2323
use crate::datasource::physical_plan::FileMeta;
2424
use crate::error::{DataFusionError, Result};
25-
use crate::physical_plan::common::AbortOnDropSingle;
2625
use crate::physical_plan::expressions::PhysicalSortExpr;
2726
use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
2827
use crate::physical_plan::{
@@ -44,7 +43,7 @@ use std::io::BufReader;
4443
use std::path::Path;
4544
use std::sync::Arc;
4645
use std::task::Poll;
47-
use tokio::task::{self, JoinHandle};
46+
use tokio::task::JoinSet;
4847

4948
use super::FileScanConfig;
5049

@@ -266,30 +265,37 @@ pub async fn plan_to_json(
266265
)));
267266
}
268267

269-
let mut tasks = vec![];
268+
let mut join_set = JoinSet::new();
270269
for i in 0..plan.output_partitioning().partition_count() {
271270
let plan = plan.clone();
272271
let filename = format!("part-{i}.json");
273272
let path = fs_path.join(filename);
274273
let file = fs::File::create(path)?;
275274
let mut writer = json::LineDelimitedWriter::new(file);
276275
let stream = plan.execute(i, task_ctx.clone())?;
277-
let handle: JoinHandle<Result<()>> = task::spawn(async move {
278-
stream
276+
join_set.spawn(async move {
277+
let result: Result<()> = stream
279278
.map(|batch| writer.write(&batch?))
280279
.try_collect()
281280
.await
282-
.map_err(DataFusionError::from)
281+
.map_err(DataFusionError::from);
282+
result
283283
});
284-
tasks.push(AbortOnDropSingle::new(handle));
285284
}
286285

287-
futures::future::join_all(tasks)
288-
.await
289-
.into_iter()
290-
.try_for_each(|result| {
291-
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
292-
})?;
286+
while let Some(result) = join_set.join_next().await {
287+
match result {
288+
Ok(res) => res?, // propagate DataFusion error
289+
Err(e) => {
290+
if e.is_panic() {
291+
std::panic::resume_unwind(e.into_panic());
292+
} else {
293+
unreachable!();
294+
}
295+
}
296+
}
297+
}
298+
293299
Ok(())
294300
}
295301

datafusion/core/src/datasource/physical_plan/parquet.rs

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ use crate::{
3131
execution::context::TaskContext,
3232
physical_optimizer::pruning::PruningPredicate,
3333
physical_plan::{
34-
common::AbortOnDropSingle,
3534
metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
3635
ordering_equivalence_properties_helper, DisplayFormatType, ExecutionPlan,
3736
Partitioning, SendableRecordBatchStream, Statistics,
@@ -64,6 +63,7 @@ use parquet::arrow::{ArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMas
6463
use parquet::basic::{ConvertedType, LogicalType};
6564
use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties};
6665
use parquet::schema::types::ColumnDescriptor;
66+
use tokio::task::JoinSet;
6767

6868
mod metrics;
6969
pub mod page_filter;
@@ -701,7 +701,7 @@ pub async fn plan_to_parquet(
701701
)));
702702
}
703703

704-
let mut tasks = vec![];
704+
let mut join_set = JoinSet::new();
705705
for i in 0..plan.output_partitioning().partition_count() {
706706
let plan = plan.clone();
707707
let filename = format!("part-{i}.parquet");
@@ -710,27 +710,30 @@ pub async fn plan_to_parquet(
710710
let mut writer =
711711
ArrowWriter::try_new(file, plan.schema(), writer_properties.clone())?;
712712
let stream = plan.execute(i, task_ctx.clone())?;
713-
let handle: tokio::task::JoinHandle<Result<()>> =
714-
tokio::task::spawn(async move {
715-
stream
716-
.map(|batch| {
717-
writer.write(&batch?).map_err(DataFusionError::ParquetError)
718-
})
719-
.try_collect()
720-
.await
721-
.map_err(DataFusionError::from)?;
713+
join_set.spawn(async move {
714+
stream
715+
.map(|batch| writer.write(&batch?).map_err(DataFusionError::ParquetError))
716+
.try_collect()
717+
.await
718+
.map_err(DataFusionError::from)?;
719+
720+
writer.close().map_err(DataFusionError::from).map(|_| ())
721+
});
722+
}
722723

723-
writer.close().map_err(DataFusionError::from).map(|_| ())
724-
});
725-
tasks.push(AbortOnDropSingle::new(handle));
724+
while let Some(result) = join_set.join_next().await {
725+
match result {
726+
Ok(res) => res?,
727+
Err(e) => {
728+
if e.is_panic() {
729+
std::panic::resume_unwind(e.into_panic());
730+
} else {
731+
unreachable!();
732+
}
733+
}
734+
}
726735
}
727736

728-
futures::future::join_all(tasks)
729-
.await
730-
.into_iter()
731-
.try_for_each(|result| {
732-
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
733-
})?;
734737
Ok(())
735738
}
736739

datafusion/core/src/physical_plan/mod.rs

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub use display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay};
3838
use futures::stream::{Stream, TryStreamExt};
3939
use std::fmt;
4040
use std::fmt::Debug;
41+
use tokio::task::JoinSet;
4142

4243
use datafusion_common::tree_node::Transformed;
4344
use datafusion_common::DataFusionError;
@@ -445,20 +446,37 @@ pub async fn collect_partitioned(
445446
) -> Result<Vec<Vec<RecordBatch>>> {
446447
let streams = execute_stream_partitioned(plan, context)?;
447448

449+
let mut join_set = JoinSet::new();
448450
// Execute the plan and collect the results into batches.
449-
let handles = streams
450-
.into_iter()
451-
.enumerate()
452-
.map(|(idx, stream)| async move {
453-
let handle = tokio::task::spawn(stream.try_collect());
454-
AbortOnDropSingle::new(handle).await.map_err(|e| {
455-
DataFusionError::Execution(format!(
456-
"collect_partitioned partition {idx} panicked: {e}"
457-
))
458-
})?
451+
streams.into_iter().enumerate().for_each(|(idx, stream)| {
452+
join_set.spawn(async move {
453+
let result: Result<Vec<RecordBatch>> = stream.try_collect().await;
454+
(idx, result)
459455
});
456+
});
457+
458+
let mut batches = vec![];
459+
// Note that currently this doesn't identify which task (partition) panicked
460+
//
461+
// TODO: Replace with [join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id)
462+
// once it is stable
463+
while let Some(result) = join_set.join_next().await {
464+
match result {
465+
Ok((idx, res)) => batches.push((idx, res?)),
466+
Err(e) => {
467+
if e.is_panic() {
468+
std::panic::resume_unwind(e.into_panic());
469+
} else {
470+
unreachable!();
471+
}
472+
}
473+
}
474+
}
475+
476+
batches.sort_by_key(|(idx, _)| *idx);
477+
let batches = batches.into_iter().map(|(_, batch)| batch).collect();
460478

461-
futures::future::try_join_all(handles).await
479+
Ok(batches)
462480
}
463481

464482
/// Execute the [ExecutionPlan] and return a vec with one stream per output partition
@@ -713,7 +731,6 @@ pub mod unnest;
713731
pub mod values;
714732
pub mod windows;
715733

716-
use crate::physical_plan::common::AbortOnDropSingle;
717734
use crate::physical_plan::repartition::RepartitionExec;
718735
use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
719736
use datafusion_execution::TaskContext;

0 commit comments

Comments
 (0)