Skip to content

Commit 6d7b902

Browse files
authored
Fix hash join with sort push down (#13560)
* fix: join with sort push down * chore: insert some value * apply suggestion * recover handle_costom_pushdown change * apply suggestion * add more test * add partition
1 parent d3cfc45 commit 6d7b902

File tree

2 files changed

+228
-44
lines changed

2 files changed

+228
-44
lines changed

datafusion/core/src/physical_optimizer/sort_pushdown.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use crate::physical_plan::repartition::RepartitionExec;
2828
use crate::physical_plan::sorts::sort::SortExec;
2929
use crate::physical_plan::tree_node::PlanContext;
3030
use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
31+
use arrow_schema::SchemaRef;
3132

3233
use datafusion_common::tree_node::{
3334
ConcreteTreeNode, Transformed, TreeNode, TreeNodeRecursion,
@@ -38,6 +39,8 @@ use datafusion_physical_expr::expressions::Column;
3839
use datafusion_physical_expr::utils::collect_columns;
3940
use datafusion_physical_expr::PhysicalSortRequirement;
4041
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
42+
use datafusion_physical_plan::joins::utils::ColumnIndex;
43+
use datafusion_physical_plan::joins::HashJoinExec;
4144

4245
/// This is a "data class" we use within the [`EnforceSorting`] rule to push
4346
/// down [`SortExec`] in the plan. In some cases, we can reduce the total
@@ -294,6 +297,8 @@ fn pushdown_requirement_to_children(
294297
.then(|| LexRequirement::new(parent_required.to_vec()));
295298
Ok(Some(vec![req]))
296299
}
300+
} else if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
301+
handle_hash_join(hash_join, parent_required)
297302
} else {
298303
handle_custom_pushdown(plan, parent_required, maintains_input_order)
299304
}
@@ -606,6 +611,102 @@ fn handle_custom_pushdown(
606611
}
607612
}
608613

614+
// For hash join we only maintain the input order for the right child
615+
// for join type: Inner, Right, RightSemi, RightAnti
616+
fn handle_hash_join(
617+
plan: &HashJoinExec,
618+
parent_required: &LexRequirement,
619+
) -> Result<Option<Vec<Option<LexRequirement>>>> {
620+
// If there's no requirement from the parent or the plan has no children
621+
// or the join type is not Inner, Right, RightSemi, RightAnti, return early
622+
if parent_required.is_empty() || !plan.maintains_input_order()[1] {
623+
return Ok(None);
624+
}
625+
626+
// Collect all unique column indices used in the parent-required sorting expression
627+
let all_indices: HashSet<usize> = parent_required
628+
.iter()
629+
.flat_map(|order| {
630+
collect_columns(&order.expr)
631+
.into_iter()
632+
.map(|col| col.index())
633+
.collect::<HashSet<_>>()
634+
})
635+
.collect();
636+
637+
let column_indices = build_join_column_index(plan);
638+
let projected_indices: Vec<_> = if let Some(projection) = &plan.projection {
639+
projection.iter().map(|&i| &column_indices[i]).collect()
640+
} else {
641+
column_indices.iter().collect()
642+
};
643+
let len_of_left_fields = projected_indices
644+
.iter()
645+
.filter(|ci| ci.side == JoinSide::Left)
646+
.count();
647+
648+
let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields);
649+
650+
// If all columns are from the right child, update the parent requirements
651+
if all_from_right_child {
652+
// Transform the parent-required expression for the child schema by adjusting columns
653+
let updated_parent_req = parent_required
654+
.iter()
655+
.map(|req| {
656+
let child_schema = plan.children()[1].schema();
657+
let updated_columns = Arc::clone(&req.expr)
658+
.transform_up(|expr| {
659+
if let Some(col) = expr.as_any().downcast_ref::<Column>() {
660+
let index = projected_indices[col.index()].index;
661+
Ok(Transformed::yes(Arc::new(Column::new(
662+
child_schema.field(index).name(),
663+
index,
664+
))))
665+
} else {
666+
Ok(Transformed::no(expr))
667+
}
668+
})?
669+
.data;
670+
Ok(PhysicalSortRequirement::new(updated_columns, req.options))
671+
})
672+
.collect::<Result<Vec<_>>>()?;
673+
674+
// Populating with the updated requirements for children that maintain order
675+
Ok(Some(vec![
676+
None,
677+
Some(LexRequirement::new(updated_parent_req)),
678+
]))
679+
} else {
680+
Ok(None)
681+
}
682+
}
683+
684+
// this function is used to build the column index for the hash join
685+
// push down sort requirements to the right child
686+
fn build_join_column_index(plan: &HashJoinExec) -> Vec<ColumnIndex> {
687+
let map_fields = |schema: SchemaRef, side: JoinSide| {
688+
schema
689+
.fields()
690+
.iter()
691+
.enumerate()
692+
.map(|(index, _)| ColumnIndex { index, side })
693+
.collect::<Vec<_>>()
694+
};
695+
696+
match plan.join_type() {
697+
JoinType::Inner | JoinType::Right => {
698+
map_fields(plan.left().schema(), JoinSide::Left)
699+
.into_iter()
700+
.chain(map_fields(plan.right().schema(), JoinSide::Right))
701+
.collect::<Vec<_>>()
702+
}
703+
JoinType::RightSemi | JoinType::RightAnti => {
704+
map_fields(plan.right().schema(), JoinSide::Right)
705+
}
706+
_ => unreachable!("unexpected join type: {}", plan.join_type()),
707+
}
708+
}
709+
609710
/// Define the Requirements Compatibility
610711
#[derive(Debug)]
611712
enum RequirementsCompatibility {

datafusion/sqllogictest/test_files/joins.slt

Lines changed: 127 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2864,13 +2864,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I
28642864
----
28652865
physical_plan
28662866
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
2867-
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
2868-
03)----CoalesceBatchesExec: target_batch_size=2
2869-
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
2870-
05)--------CoalesceBatchesExec: target_batch_size=2
2871-
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
2872-
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
2873-
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
2867+
02)--CoalesceBatchesExec: target_batch_size=2
2868+
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
2869+
04)------CoalesceBatchesExec: target_batch_size=2
2870+
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
2871+
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
2872+
07)------------MemoryExec: partitions=1, partition_sizes=[1]
2873+
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
28742874
09)--------CoalesceBatchesExec: target_batch_size=2
28752875
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
28762876
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
@@ -2905,13 +2905,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI
29052905
----
29062906
physical_plan
29072907
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
2908-
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
2909-
03)----CoalesceBatchesExec: target_batch_size=2
2910-
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
2911-
05)--------CoalesceBatchesExec: target_batch_size=2
2912-
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
2913-
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
2914-
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
2908+
02)--CoalesceBatchesExec: target_batch_size=2
2909+
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
2910+
04)------CoalesceBatchesExec: target_batch_size=2
2911+
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
2912+
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
2913+
07)------------MemoryExec: partitions=1, partition_sizes=[1]
2914+
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
29152915
09)--------CoalesceBatchesExec: target_batch_size=2
29162916
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
29172917
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
@@ -2967,10 +2967,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I
29672967
----
29682968
physical_plan
29692969
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
2970-
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
2971-
03)----CoalesceBatchesExec: target_batch_size=2
2972-
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
2973-
05)--------MemoryExec: partitions=1, partition_sizes=[1]
2970+
02)--CoalesceBatchesExec: target_batch_size=2
2971+
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
2972+
04)------MemoryExec: partitions=1, partition_sizes=[1]
2973+
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
29742974
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
29752975
07)----------MemoryExec: partitions=1, partition_sizes=[1]
29762976

@@ -3003,10 +3003,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI
30033003
----
30043004
physical_plan
30053005
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
3006-
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
3007-
03)----CoalesceBatchesExec: target_batch_size=2
3008-
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
3009-
05)--------MemoryExec: partitions=1, partition_sizes=[1]
3006+
02)--CoalesceBatchesExec: target_batch_size=2
3007+
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
3008+
04)------MemoryExec: partitions=1, partition_sizes=[1]
3009+
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
30103010
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
30113011
07)----------MemoryExec: partitions=1, partition_sizes=[1]
30123012

@@ -3061,13 +3061,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER
30613061
----
30623062
physical_plan
30633063
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
3064-
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
3065-
03)----CoalesceBatchesExec: target_batch_size=2
3066-
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
3067-
05)--------CoalesceBatchesExec: target_batch_size=2
3068-
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
3069-
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
3070-
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
3064+
02)--CoalesceBatchesExec: target_batch_size=2
3065+
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
3066+
04)------CoalesceBatchesExec: target_batch_size=2
3067+
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
3068+
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
3069+
07)------------MemoryExec: partitions=1, partition_sizes=[1]
3070+
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
30713071
09)--------CoalesceBatchesExec: target_batch_size=2
30723072
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
30733073
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
@@ -3083,13 +3083,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH
30833083
----
30843084
physical_plan
30853085
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
3086-
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
3087-
03)----CoalesceBatchesExec: target_batch_size=2
3088-
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
3089-
05)--------CoalesceBatchesExec: target_batch_size=2
3090-
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
3091-
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
3092-
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
3086+
02)--CoalesceBatchesExec: target_batch_size=2
3087+
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
3088+
04)------CoalesceBatchesExec: target_batch_size=2
3089+
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
3090+
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
3091+
07)------------MemoryExec: partitions=1, partition_sizes=[1]
3092+
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
30933093
09)--------CoalesceBatchesExec: target_batch_size=2
30943094
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
30953095
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
@@ -3143,10 +3143,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER
31433143
----
31443144
physical_plan
31453145
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
3146-
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
3147-
03)----CoalesceBatchesExec: target_batch_size=2
3148-
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
3149-
05)--------MemoryExec: partitions=1, partition_sizes=[1]
3146+
02)--CoalesceBatchesExec: target_batch_size=2
3147+
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
3148+
04)------MemoryExec: partitions=1, partition_sizes=[1]
3149+
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
31503150
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
31513151
07)----------MemoryExec: partitions=1, partition_sizes=[1]
31523152

@@ -3160,10 +3160,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH
31603160
----
31613161
physical_plan
31623162
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
3163-
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
3164-
03)----CoalesceBatchesExec: target_batch_size=2
3165-
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
3166-
05)--------MemoryExec: partitions=1, partition_sizes=[1]
3163+
02)--CoalesceBatchesExec: target_batch_size=2
3164+
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
3165+
04)------MemoryExec: partitions=1, partition_sizes=[1]
3166+
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
31673167
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
31683168
07)----------MemoryExec: partitions=1, partition_sizes=[1]
31693169

@@ -4313,3 +4313,86 @@ physical_plan
43134313
04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)]
43144314
05)--------MemoryExec: partitions=1, partition_sizes=[1]
43154315
06)--------MemoryExec: partitions=1, partition_sizes=[1]
4316+
4317+
# Test hash join sort push down
4318+
# Issue: https://github.com/apache/datafusion/issues/13559
4319+
statement ok
4320+
CREATE TABLE test(a INT, b INT, c INT)
4321+
4322+
statement ok
4323+
insert into test values (1,2,3), (4,5,6), (null, 7, 8), (8, null, 9), (9, 10, null)
4324+
4325+
statement ok
4326+
set datafusion.execution.target_partitions = 2;
4327+
4328+
query TT
4329+
explain select * from test where a in (select a from test where b > 3) order by c desc nulls first;
4330+
----
4331+
logical_plan
4332+
01)Sort: test.c DESC NULLS FIRST
4333+
02)--LeftSemi Join: test.a = __correlated_sq_1.a
4334+
03)----TableScan: test projection=[a, b, c]
4335+
04)----SubqueryAlias: __correlated_sq_1
4336+
05)------Projection: test.a
4337+
06)--------Filter: test.b > Int32(3)
4338+
07)----------TableScan: test projection=[a, b]
4339+
physical_plan
4340+
01)SortPreservingMergeExec: [c@2 DESC]
4341+
02)--CoalesceBatchesExec: target_batch_size=3
4342+
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)]
4343+
04)------CoalesceBatchesExec: target_batch_size=3
4344+
05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
4345+
06)----------CoalesceBatchesExec: target_batch_size=3
4346+
07)------------FilterExec: b@1 > 3, projection=[a@0]
4347+
08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
4348+
09)----------------MemoryExec: partitions=1, partition_sizes=[1]
4349+
10)------SortExec: expr=[c@2 DESC], preserve_partitioning=[true]
4350+
11)--------CoalesceBatchesExec: target_batch_size=3
4351+
12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
4352+
13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
4353+
14)--------------MemoryExec: partitions=1, partition_sizes=[1]
4354+
4355+
query TT
4356+
explain select * from test where a in (select a from test where b > 3) order by c desc nulls last;
4357+
----
4358+
logical_plan
4359+
01)Sort: test.c DESC NULLS LAST
4360+
02)--LeftSemi Join: test.a = __correlated_sq_1.a
4361+
03)----TableScan: test projection=[a, b, c]
4362+
04)----SubqueryAlias: __correlated_sq_1
4363+
05)------Projection: test.a
4364+
06)--------Filter: test.b > Int32(3)
4365+
07)----------TableScan: test projection=[a, b]
4366+
physical_plan
4367+
01)SortPreservingMergeExec: [c@2 DESC NULLS LAST]
4368+
02)--CoalesceBatchesExec: target_batch_size=3
4369+
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)]
4370+
04)------CoalesceBatchesExec: target_batch_size=3
4371+
05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
4372+
06)----------CoalesceBatchesExec: target_batch_size=3
4373+
07)------------FilterExec: b@1 > 3, projection=[a@0]
4374+
08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
4375+
09)----------------MemoryExec: partitions=1, partition_sizes=[1]
4376+
10)------SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[true]
4377+
11)--------CoalesceBatchesExec: target_batch_size=3
4378+
12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
4379+
13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
4380+
14)--------------MemoryExec: partitions=1, partition_sizes=[1]
4381+
4382+
query III
4383+
select * from test where a in (select a from test where b > 3) order by c desc nulls first;
4384+
----
4385+
9 10 NULL
4386+
4 5 6
4387+
4388+
query III
4389+
select * from test where a in (select a from test where b > 3) order by c desc nulls last;
4390+
----
4391+
4 5 6
4392+
9 10 NULL
4393+
4394+
statement ok
4395+
DROP TABLE test
4396+
4397+
statement ok
4398+
set datafusion.execution.target_partitions = 1;

0 commit comments

Comments
 (0)