Commit 450c861

wiedld and alamb authored
Refactor SortPushdown using the standard top-down visitor and using EquivalenceProperties (#14821)
* refactor: have sort pushdown use transform_down, and provide a minor refactor in sort_pushdown_helper to make it more understandable
* test: inconsequential single change in test
* Use consistent variable naming
* chore: update variable naming
* refactor: only sync the plan children when required
* fix: have orderings include constants which are heterogeneous across partitions
* Revert "fix: have orderings include constants which are heterogeneous across partitions" (this reverts commit 4775354)
* test: temporary commit to demonstrate changes that only occur with no partition by (in window agg), and when aggregating on an unordered column
* Revert "test: temporary commit to demonstrate changes that only occur with no partition by (in window agg), and when aggregating on an unordered column" (this reverts commit 2ee747f)
* chore: cleanup after merging main, for anticipated test change
* chore: rename variable
* refactor: added test cases for orthogonal sorting, and removed 1 unneeded conditional
* chore: remove unneeded conditional and make a comment

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 13b731c commit 450c861
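
For readers skimming the diff below: the rule this commit refactors rewrites the physical plan with DataFusion's standard top-down TreeNode traversal (transform_down) and consults EquivalenceProperties to decide whether a sort is still needed. The sketch below is only an illustration of that idea, not the code in this commit; exact method names and signatures (for example ordering_satisfy and the Transformed API) vary across DataFusion versions.

use std::sync::Arc;

use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::Result;
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::ExecutionPlan;

// Illustrative only: a top-down pass that drops a SortExec whose requested
// ordering is already guaranteed by its input's EquivalenceProperties. The
// real SortPushdown rule also pushes sort requirements into children rather
// than merely deleting redundant sorts.
fn remove_redundant_sorts(
    plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let transformed = plan.transform_down(|node| {
        if let Some(sort) = node.as_any().downcast_ref::<SortExec>() {
            // A sort with a fetch behaves like a TopK and changes the result
            // set, so it must be kept even when the ordering already holds.
            let already_ordered = sort
                .input()
                .equivalence_properties()
                .ordering_satisfy(sort.expr());
            if sort.fetch().is_none() && already_ordered {
                // Replace the sort with its input; transform_down keeps
                // recursing into the children of the node we return.
                return Ok(Transformed::yes(Arc::clone(sort.input())));
            }
        }
        Ok(Transformed::no(node))
    })?;
    Ok(transformed.data)
}

The new tests added to enforce_sorting.rs below exercise exactly this distinction: plain "orthogonal" sorts are removed, while a sort with a fetch (a TopK) is kept because it changes the result set.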

5 files changed: +199 / -81 lines


datafusion/core/tests/physical_optimizer/enforce_distribution.rs

Lines changed: 1 addition & 1 deletion
@@ -2388,7 +2388,7 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> {
     );
 
     let expected = &[
-        "SortExec: expr=[c@2 ASC], preserve_partitioning=[true]",
+        "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]",
         // Since this projection is trivial, increasing parallelism is not beneficial
         " ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]",
         " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet",

datafusion/core/tests/physical_optimizer/enforce_sorting.rs

Lines changed: 94 additions & 8 deletions
@@ -21,10 +21,10 @@ use crate::physical_optimizer::test_utils::{
     aggregate_exec, bounded_window_exec, check_integrity, coalesce_batches_exec,
     coalesce_partitions_exec, create_test_schema, create_test_schema2,
     create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, limit_exec,
-    local_limit_exec, memory_exec, parquet_exec, repartition_exec, sort_exec, sort_expr,
-    sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec,
-    sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered,
-    union_exec, RequirementsTestExec,
+    local_limit_exec, memory_exec, parquet_exec, repartition_exec, sort_exec,
+    sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec,
+    sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch,
+    spr_repartition_exec, stream_exec_ordered, union_exec, RequirementsTestExec,
 };
 
 use arrow::compute::SortOptions;
@@ -2242,7 +2242,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> {
         " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet",
     ],
     expected_plan: vec![
-        "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]",
+        "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]",
         " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]",
         " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet",
     ],
@@ -2259,7 +2259,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> {
         " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet",
     ],
     expected_plan: vec![
-        "SortExec: expr=[non_nullable_col@1 DESC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]",
+        "SortExec: expr=[non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]",
         " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]",
         " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet",
     ],
@@ -2276,7 +2276,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> {
         " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet",
     ],
     expected_plan: vec![
-        "SortExec: expr=[min@2 ASC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]",
+        "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]",
         " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]",
         " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet",
     ],
@@ -2293,7 +2293,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> {
         " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet",
     ],
     expected_plan: vec![
-        "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]",
+        "SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]",
         " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]",
         " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet",
     ],
@@ -3346,3 +3346,89 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> {
 
     Ok(())
 }
+
+#[test]
+fn test_removes_unused_orthogonal_sort() -> Result<()> {
+    let schema = create_test_schema3()?;
+    let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)];
+    let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone());
+
+    let orthogonal_sort = sort_exec(vec![sort_expr("a", &schema)], unbounded_input);
+    let output_sort = sort_exec(input_sort_exprs, orthogonal_sort); // same sort as data source
+
+    // Test scenario/input has an orthogonal sort:
+    let expected_input = [
+        "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]",
+        " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]",
+        " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]"
+    ];
+    assert_eq!(get_plan_string(&output_sort), expected_input,);
+
+    // Test: should remove orthogonal sort, and the uppermost (unneeded) sort:
+    let expected_optimized = [
+        "StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]"
+    ];
+    assert_optimized!(expected_input, expected_optimized, output_sort, true);
+
+    Ok(())
+}
+
+#[test]
+fn test_keeps_used_orthogonal_sort() -> Result<()> {
+    let schema = create_test_schema3()?;
+    let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)];
+    let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone());
+
+    let orthogonal_sort =
+        sort_exec_with_fetch(vec![sort_expr("a", &schema)], Some(3), unbounded_input); // has fetch, so this orthogonal sort changes the output
+    let output_sort = sort_exec(input_sort_exprs, orthogonal_sort);
+
+    // Test scenario/input has an orthogonal sort:
+    let expected_input = [
+        "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]",
+        " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]",
+        " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]"
+    ];
+    assert_eq!(get_plan_string(&output_sort), expected_input,);
+
+    // Test: should keep the orthogonal sort, since it modifies the output:
+    let expected_optimized = expected_input;
+    assert_optimized!(expected_input, expected_optimized, output_sort, true);
+
+    Ok(())
+}
+
+#[test]
+fn test_handles_multiple_orthogonal_sorts() -> Result<()> {
+    let schema = create_test_schema3()?;
+    let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)];
+    let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone());
+
+    let orthogonal_sort_0 = sort_exec(vec![sort_expr("c", &schema)], unbounded_input); // has no fetch, so can be removed
+    let orthogonal_sort_1 =
+        sort_exec_with_fetch(vec![sort_expr("a", &schema)], Some(3), orthogonal_sort_0); // has fetch, so this orthogonal sort changes the output
+    let orthogonal_sort_2 = sort_exec(vec![sort_expr("c", &schema)], orthogonal_sort_1); // has no fetch, so can be removed
+    let orthogonal_sort_3 = sort_exec(vec![sort_expr("a", &schema)], orthogonal_sort_2); // has no fetch, so can be removed
+    let output_sort = sort_exec(input_sort_exprs, orthogonal_sort_3); // final sort
+
+    // Test scenario/input has an orthogonal sort:
+    let expected_input = [
+        "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]",
+        " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]",
+        " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]",
+        " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]",
+        " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]",
+        " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]",
+    ];
+    assert_eq!(get_plan_string(&output_sort), expected_input,);
+
+    // Test: should keep only the needed orthogonal sort, and remove the unneeded ones:
+    let expected_optimized = [
+        "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]",
+        " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]",
+        " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]",
+    ];
+    assert_optimized!(expected_input, expected_optimized, output_sort, true);
+
+    Ok(())
+}
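
For context on what assert_optimized! checks in the new tests above: the macro renders the plan before and after the sorting-related optimizer passes and compares the strings. A rough stand-alone equivalent, assuming the EnforceSorting rule lives in the datafusion-physical-optimizer crate in this version (an assumption; the macro itself may also run additional sub-passes), could look like this:

use std::sync::Arc;

use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_physical_optimizer::enforce_sorting::EnforceSorting;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::{displayable, ExecutionPlan};

// Apply the EnforceSorting physical optimizer rule and render the result in
// the same indented form that the expected_* string arrays above match.
fn optimize_and_render(plan: Arc<dyn ExecutionPlan>) -> Result<Vec<String>> {
    let optimized = EnforceSorting::new().optimize(plan, &ConfigOptions::default())?;
    let rendered = displayable(optimized.as_ref()).indent(true).to_string();
    Ok(rendered.trim_end().split('\n').map(|s| s.to_owned()).collect())
}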

datafusion/core/tests/physical_optimizer/test_utils.rs

Lines changed: 9 additions & 1 deletion
@@ -295,9 +295,17 @@ pub fn coalesce_batches_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn Execution
 pub fn sort_exec(
     sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
     input: Arc<dyn ExecutionPlan>,
+) -> Arc<dyn ExecutionPlan> {
+    sort_exec_with_fetch(sort_exprs, None, input)
+}
+
+pub fn sort_exec_with_fetch(
+    sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
+    fetch: Option<usize>,
+    input: Arc<dyn ExecutionPlan>,
 ) -> Arc<dyn ExecutionPlan> {
     let sort_exprs = sort_exprs.into_iter().collect();
-    Arc::new(SortExec::new(sort_exprs, input))
+    Arc::new(SortExec::new(sort_exprs, input).with_fetch(fetch))
 }
 
 /// A test [`ExecutionPlan`] whose requirements can be configured.
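
A quick note on the helper added above: passing Some(n) as the fetch turns the SortExec into a top-k sort, which is why the plans in the new tests print "SortExec: TopK(fetch=3), ...". The snippet below is a hypothetical usage example, not part of the commit; it assumes the other utilities already defined in this file (create_test_schema3, sort_expr, memory_exec) with their usual signatures.

#[test]
fn example_sort_exec_with_fetch_usage() -> Result<()> {
    // Hypothetical example: build a plain sort and a top-k sort over the same
    // in-memory source using the helpers from this file.
    let schema = create_test_schema3()?;
    let source = memory_exec(&schema);

    // Equivalent to the old sort_exec behavior: no fetch limit.
    let _plain_sort = sort_exec(vec![sort_expr("a", &schema)], Arc::clone(&source));

    // With a fetch of 3 the node is displayed as "SortExec: TopK(fetch=3), ...".
    let _top3 = sort_exec_with_fetch(vec![sort_expr("a", &schema)], Some(3), source);

    Ok(())
}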
