
Commit d594e62

viirya and alamb authored
Relax join keys constraint from Column to any physical expression for physical join operators (#8991)
* Relax SortMergeJoin join keys
* More
* More
* More
* More
* Fix clippy
* Fix more clippy
* More
* More
* Fix
* Fix
* Use collect_columns

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 92104a5 commit d594e62
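The gist of the change: the equi-join keys (`JoinOn`) handed to the physical join operators are now pairs of `Arc<dyn PhysicalExpr>` (the `PhysicalExprRef` alias) rather than concrete `Column`s, so any physical expression can act as a join key. A minimal sketch of the new shape, not taken from this commit, with made-up column names and indices:

use std::sync::Arc;

use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExprRef; // alias for Arc<dyn PhysicalExpr>

fn main() {
    // A bare column key still works; it is just wrapped in an Arc and
    // coerced to the trait object with `as _`.
    let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = vec![(
        Arc::new(Column::new("a", 0)) as _,
        Arc::new(Column::new("b", 1)) as _,
    )];
    assert_eq!(on.len(), 1);
}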

File tree

18 files changed: +691 -511 lines changed


datafusion/core/src/physical_optimizer/enforce_distribution.rs

Lines changed: 170 additions & 121 deletions
Large diffs are not rendered by default.

datafusion/core/src/physical_optimizer/enforce_sorting.rs

Lines changed: 11 additions & 8 deletions
@@ -985,8 +985,8 @@ mod tests {
         let right_input = parquet_exec_sorted(&right_schema, parquet_sort_exprs);
 
         let on = vec![(
-            Column::new_with_schema("col_a", &left_schema)?,
-            Column::new_with_schema("c", &right_schema)?,
+            Arc::new(Column::new_with_schema("col_a", &left_schema)?) as _,
+            Arc::new(Column::new_with_schema("c", &right_schema)?) as _,
         )];
         let join = hash_join_exec(left_input, right_input, on, None, &JoinType::Inner)?;
         let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join);
@@ -1639,8 +1639,9 @@ mod tests {
 
         // Join on (nullable_col == col_a)
         let join_on = vec![(
-            Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
-            Column::new_with_schema("col_a", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap())
+                as _,
+            Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _,
         )];
 
         let join_types = vec![
@@ -1711,8 +1712,9 @@ mod tests {
 
         // Join on (nullable_col == col_a)
         let join_on = vec![(
-            Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
-            Column::new_with_schema("col_a", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap())
+                as _,
+            Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _,
         )];
 
         let join_types = vec![
@@ -1785,8 +1787,9 @@ mod tests {
 
         // Join on (nullable_col == col_a)
         let join_on = vec![(
-            Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
-            Column::new_with_schema("col_a", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap())
+                as _,
+            Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _,
         )];
 
         let join = sort_merge_join_exec(left, right, &join_on, &JoinType::Inner);
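A note on the trailing `as _` casts above: these `join_on` vectors are bound to locals without type annotations, so each `Arc<Column>` is explicitly cast to the `Arc<dyn PhysicalExpr>` trait object and inference fills in the target type from the later call into the join constructor. A rough stand-alone illustration; the `takes_join_on` helper is invented for the example and merely stands in for `hash_join_exec` / `sort_merge_join_exec`:

use std::sync::Arc;

use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExprRef;

// Hypothetical stand-in for the join constructors used in the tests above.
fn takes_join_on(on: Vec<(PhysicalExprRef, PhysicalExprRef)>) -> usize {
    on.len()
}

fn main() {
    // No annotation on `join_on`: the `as _` casts plus the call below pin the
    // element type to (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>).
    let join_on = vec![(
        Arc::new(Column::new("nullable_col", 0)) as _,
        Arc::new(Column::new("col_a", 0)) as _,
    )];
    assert_eq!(takes_join_on(join_on), 1);
}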

datafusion/core/src/physical_optimizer/join_selection.rs

Lines changed: 47 additions & 32 deletions
@@ -690,7 +690,7 @@ mod tests_statistical {
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::{stats::Precision, JoinType, ScalarValue};
     use datafusion_physical_expr::expressions::Column;
-    use datafusion_physical_expr::PhysicalExpr;
+    use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
 
     /// Return statistcs for empty table
     fn empty_statistics() -> Statistics {
@@ -860,8 +860,10 @@
             Arc::clone(&big),
             Arc::clone(&small),
             vec![(
-                Column::new_with_schema("big_col", &big.schema()).unwrap(),
-                Column::new_with_schema("small_col", &small.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
+                Arc::new(
+                    Column::new_with_schema("small_col", &small.schema()).unwrap(),
+                ),
             )],
             None,
             &JoinType::Left,
@@ -914,8 +916,10 @@
             Arc::clone(&small),
             Arc::clone(&big),
             vec![(
-                Column::new_with_schema("small_col", &small.schema()).unwrap(),
-                Column::new_with_schema("big_col", &big.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("small_col", &small.schema()).unwrap(),
+                ),
+                Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
             )],
             None,
             &JoinType::Left,
@@ -970,8 +974,13 @@
             Arc::clone(&big),
             Arc::clone(&small),
             vec![(
-                Column::new_with_schema("big_col", &big.schema()).unwrap(),
-                Column::new_with_schema("small_col", &small.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("big_col", &big.schema()).unwrap(),
+                ),
+                Arc::new(
+                    Column::new_with_schema("small_col", &small.schema())
+                        .unwrap(),
+                ),
             )],
             None,
             &join_type,
@@ -1040,8 +1049,8 @@
             Arc::clone(&big),
             Arc::clone(&small),
             vec![(
-                Column::new_with_schema("big_col", &big.schema()).unwrap(),
-                Column::new_with_schema("small_col", &small.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
+                Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()),
             )],
             None,
             &JoinType::Inner,
@@ -1056,8 +1065,10 @@
             Arc::clone(&medium),
             Arc::new(child_join),
             vec![(
-                Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
-                Column::new_with_schema("small_col", &child_schema).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
+                ),
+                Arc::new(Column::new_with_schema("small_col", &child_schema).unwrap()),
             )],
             None,
             &JoinType::Left,
@@ -1094,8 +1105,10 @@
             Arc::clone(&small),
             Arc::clone(&big),
             vec![(
-                Column::new_with_schema("small_col", &small.schema()).unwrap(),
-                Column::new_with_schema("big_col", &big.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("small_col", &small.schema()).unwrap(),
+                ),
+                Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
             )],
             None,
             &JoinType::Inner,
@@ -1178,8 +1191,8 @@
         ));
 
         let join_on = vec![(
-            Column::new_with_schema("small_col", &small.schema()).unwrap(),
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
            small.clone(),
@@ -1190,8 +1203,8 @@
         );
 
         let join_on = vec![(
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
-            Column::new_with_schema("small_col", &small.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
            big.clone(),
@@ -1202,8 +1215,8 @@
         );
 
         let join_on = vec![(
-            Column::new_with_schema("small_col", &small.schema()).unwrap(),
-            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
            small.clone(),
@@ -1214,8 +1227,8 @@
         );
 
         let join_on = vec![(
-            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
-            Column::new_with_schema("small_col", &small.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
            empty.clone(),
@@ -1244,8 +1257,9 @@
         ));
 
         let join_on = vec![(
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
-            Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("bigger_col", &bigger.schema()).unwrap())
+                as _,
         )];
         check_join_partition_mode(
            big.clone(),
@@ -1256,8 +1270,9 @@
         );
 
         let join_on = vec![(
-            Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(),
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("bigger_col", &bigger.schema()).unwrap())
+                as _,
+            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
            bigger.clone(),
@@ -1268,8 +1283,8 @@
         );
 
         let join_on = vec![(
-            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
            empty.clone(),
@@ -1280,16 +1295,16 @@
         );
 
         let join_on = vec![(
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
-            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(big, empty, join_on, false, PartitionMode::Partitioned);
     }
 
     fn check_join_partition_mode(
         left: Arc<StatisticsExec>,
         right: Arc<StatisticsExec>,
-        on: Vec<(Column, Column)>,
+        on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
         is_swapped: bool,
         expected_mode: PartitionMode,
     ) {
@@ -1748,8 +1763,8 @@ mod hash_join_tests {
             Arc::clone(&left_exec),
             Arc::clone(&right_exec),
             vec![(
-                Column::new_with_schema("a", &left_exec.schema())?,
-                Column::new_with_schema("b", &right_exec.schema())?,
+                Arc::new(Column::new_with_schema("a", &left_exec.schema())?),
+                Arc::new(Column::new_with_schema("b", &right_exec.schema())?),
             )],
             None,
             &t.initial_join_type,
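The commit message's "Use collect_columns" refers to the other side of this relaxation: code that used to read a key's name straight off a `Column` now has to pull the referenced columns out of an arbitrary expression. A small sketch of that pattern, assuming `collect_columns` from `datafusion_physical_expr::utils` (the key itself is made up and not part of this diff):

use std::sync::Arc;

use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::PhysicalExprRef;

fn main() {
    // A join key is now an opaque physical expression; collect_columns
    // gathers every Column it references.
    let key: PhysicalExprRef = Arc::new(Column::new("big_col", 0));
    let cols = collect_columns(&key);
    assert!(cols.iter().any(|c| c.name() == "big_col"));
}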

datafusion/core/src/physical_optimizer/projection_pushdown.rs

Lines changed: 37 additions & 12 deletions
@@ -44,10 +44,11 @@ use crate::physical_plan::{Distribution, ExecutionPlan};
 use arrow_schema::SchemaRef;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
-use datafusion_common::JoinSide;
+use datafusion_common::{DataFusionError, JoinSide};
 use datafusion_physical_expr::expressions::{Column, Literal};
 use datafusion_physical_expr::{
-    Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
+    Partitioning, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
+    PhysicalSortRequirement,
 };
 use datafusion_physical_plan::streaming::StreamingTableExec;
 use datafusion_physical_plan::union::UnionExec;
@@ -1000,8 +1001,8 @@ fn join_table_borders(
 fn update_join_on(
     proj_left_exprs: &[(Column, String)],
     proj_right_exprs: &[(Column, String)],
-    hash_join_on: &[(Column, Column)],
-) -> Option<Vec<(Column, Column)>> {
+    hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)],
+) -> Option<Vec<(PhysicalExprRef, PhysicalExprRef)>> {
     // TODO: Clippy wants the "map" call removed, but doing so generates
     // a compilation error. Remove the clippy directive once this
     // issue is fixed.
@@ -1024,17 +1025,41 @@
 /// operation based on a set of equi-join conditions (`hash_join_on`) and a
 /// list of projection expressions (`projection_exprs`).
 fn new_columns_for_join_on(
-    hash_join_on: &[&Column],
+    hash_join_on: &[&PhysicalExprRef],
     projection_exprs: &[(Column, String)],
-) -> Option<Vec<Column>> {
+) -> Option<Vec<PhysicalExprRef>> {
     let new_columns = hash_join_on
         .iter()
         .filter_map(|on| {
-            projection_exprs
-                .iter()
-                .enumerate()
-                .find(|(_, (proj_column, _))| on.name() == proj_column.name())
-                .map(|(index, (_, alias))| Column::new(alias, index))
+            // Rewrite all columns in `on`
+            (*on)
+                .clone()
+                .transform(&|expr| {
+                    if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+                        // Find the column in the projection expressions
+                        let new_column = projection_exprs
+                            .iter()
+                            .enumerate()
+                            .find(|(_, (proj_column, _))| {
+                                column.name() == proj_column.name()
+                            })
+                            .map(|(index, (_, alias))| Column::new(alias, index));
+                        if let Some(new_column) = new_column {
+                            Ok(Transformed::Yes(Arc::new(new_column)))
+                        } else {
+                            // If the column is not found in the projection expressions,
+                            // it means that the column is not projected. In this case,
+                            // we cannot push the projection down.
+                            Err(DataFusionError::Internal(format!(
+                                "Column {:?} not found in projection expressions",
+                                column
+                            )))
+                        }
+                    } else {
+                        Ok(Transformed::No(expr))
+                    }
+                })
+                .ok()
         })
         .collect::<Vec<_>>();
     (new_columns.len() == hash_join_on.len()).then_some(new_columns)
@@ -2018,7 +2043,7 @@ mod tests {
         let join: Arc<dyn ExecutionPlan> = Arc::new(SymmetricHashJoinExec::try_new(
             left_csv,
             right_csv,
-            vec![(Column::new("b", 1), Column::new("c", 2))],
+            vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))],
             // b_left-(1+a_right)<=a_right+c_left
             Some(JoinFilter::new(
                 Arc::new(BinaryExpr::new(
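The rewritten `new_columns_for_join_on` no longer matches a single `Column` by name; it walks each join-key expression with `TreeNode::transform` and remaps every embedded column to its post-projection alias and index. A stand-alone sketch of the same traversal, using the TreeNode API as it appears in this diff, on a made-up key `a + 1` (the alias and index are invented):

use std::sync::Arc;

use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::PhysicalExprRef;

fn main() -> Result<()> {
    // A join key that is an expression rather than a bare column: a + 1.
    let key: PhysicalExprRef = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));

    // Remap "a" to its name and index after the projection, leaving everything
    // else untouched, the same shape as the closure in the diff above.
    let rewritten = key.transform(&|expr| {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            if column.name() == "a" {
                return Ok(Transformed::Yes(Arc::new(Column::new("a_alias", 3))));
            }
        }
        Ok(Transformed::No(expr))
    })?;

    println!("{rewritten}"); // prints the rewritten key, e.g. a_alias@3 + 1
    Ok(())
}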

datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs

Lines changed: 1 addition & 1 deletion
@@ -1440,7 +1440,7 @@ mod tests {
         HashJoinExec::try_new(
             left,
             right,
-            vec![(left_col.clone(), right_col.clone())],
+            vec![(Arc::new(left_col.clone()), Arc::new(right_col.clone()))],
             None,
             &JoinType::Inner,
             PartitionMode::Partitioned,

datafusion/core/src/physical_planner.rs

Lines changed: 12 additions & 6 deletions
@@ -1036,15 +1036,21 @@ impl DefaultPhysicalPlanner {
                 let [physical_left, physical_right]: [Arc<dyn ExecutionPlan>; 2] = left_right.try_into().map_err(|_| DataFusionError::Internal("`create_initial_plan_multi` is broken".to_string()))?;
                 let left_df_schema = left.schema();
                 let right_df_schema = right.schema();
+                let execution_props = session_state.execution_props();
                 let join_on = keys
                     .iter()
                     .map(|(l, r)| {
-                        let l = l.try_into_col()?;
-                        let r = r.try_into_col()?;
-                        Ok((
-                            Column::new(&l.name, left_df_schema.index_of_column(&l)?),
-                            Column::new(&r.name, right_df_schema.index_of_column(&r)?),
-                        ))
+                        let l = create_physical_expr(
+                            l,
+                            left_df_schema,
+                            execution_props
+                        )?;
+                        let r = create_physical_expr(
+                            r,
+                            right_df_schema,
+                            execution_props
+                        )?;
+                        Ok((l, r))
                     })
                     .collect::<Result<join_utils::JoinOn>>()?;
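On the planner side, each logical join-key expression is now lowered with `create_physical_expr` against the corresponding input's DFSchema instead of being resolved into a `Column` by name. A rough, self-contained sketch of that lowering; the schema and key are made up, and the import path for `create_physical_expr` follows recent DataFusion releases and may differ slightly at this exact revision:

use arrow_schema::{DataType, Field, Schema};
use datafusion::physical_expr::create_physical_expr;
use datafusion::prelude::{col, SessionContext};
use datafusion_common::{DFSchema, Result};

fn main() -> Result<()> {
    // One side of a join: its schema and a logical key expression on it.
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let df_schema = DFSchema::try_from(schema)?;
    let state = SessionContext::new().state();

    // Lower the logical key to an Arc<dyn PhysicalExpr>, as the planner now
    // does for both sides of each join-on pair.
    let physical_key = create_physical_expr(&col("a"), &df_schema, state.execution_props())?;
    println!("{physical_key}"); // e.g. a@0
    Ok(())
}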

datafusion/core/tests/fuzz_cases/join_fuzz.rs

Lines changed: 4 additions & 4 deletions
@@ -109,12 +109,12 @@ async fn run_join_test(
     let schema2 = input2[0].schema();
     let on_columns = vec![
         (
-            Column::new_with_schema("a", &schema1).unwrap(),
-            Column::new_with_schema("a", &schema2).unwrap(),
+            Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
+            Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
         ),
         (
-            Column::new_with_schema("b", &schema1).unwrap(),
-            Column::new_with_schema("b", &schema2).unwrap(),
+            Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
+            Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
         ),
     ];
