Skip to content

Commit e7ac843

Browse files
Bug-fix: MemoryExec sort expressions do NOT refer to the projected schema (#12876)
* Update memory.rs
* add assert
* Update memory.rs
* Update memory.rs
* Update memory.rs
* address review
* Update memory.rs
* Update memory.rs
* final fix
* Fix comments in test_utils.rs

---------

Co-authored-by: Mehmet Ozan Kabak <[email protected]>
1 parent 6267ede commit e7ac843

File tree

10 files changed

+78
-24
lines changed

10 files changed

+78
-24
lines changed

datafusion/core/src/datasource/memory.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ use crate::physical_planner::create_physical_sort_exprs;
3737

3838
use arrow::datatypes::SchemaRef;
3939
use arrow::record_batch::RecordBatch;
40+
use datafusion_catalog::Session;
4041
use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt};
4142
use datafusion_execution::TaskContext;
4243
use datafusion_expr::dml::InsertOp;
44+
use datafusion_expr::SortExpr;
4345
use datafusion_physical_plan::metrics::MetricsSet;
4446

4547
use async_trait::async_trait;
46-
use datafusion_catalog::Session;
47-
use datafusion_expr::SortExpr;
4848
use futures::StreamExt;
4949
use log::debug;
5050
use parking_lot::Mutex;
@@ -241,7 +241,7 @@ impl TableProvider for MemTable {
241241
)
242242
})
243243
.collect::<Result<Vec<_>>>()?;
244-
exec = exec.with_sort_information(file_sort_order);
244+
exec = exec.try_with_sort_information(file_sort_order)?;
245245
}
246246

247247
Ok(Arc::new(exec))

datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,8 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>, group_by_columns: Vec<&str
395395
let running_source = Arc::new(
396396
MemoryExec::try_new(&[input1.clone()], schema.clone(), None)
397397
.unwrap()
398-
.with_sort_information(vec![sort_keys]),
398+
.try_with_sort_information(vec![sort_keys])
399+
.unwrap(),
399400
);
400401

401402
let aggregate_expr =

datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,8 @@ mod sp_repartition_fuzz_tests {
358358
let running_source = Arc::new(
359359
MemoryExec::try_new(&[input1.clone()], schema.clone(), None)
360360
.unwrap()
361-
.with_sort_information(vec![sort_keys.clone()]),
361+
.try_with_sort_information(vec![sort_keys.clone()])
362+
.unwrap(),
362363
);
363364
let hash_exprs = vec![col("c", &schema).unwrap()];
364365

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ async fn run_window_test(
647647
];
648648
let mut exec1 = Arc::new(
649649
MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None)?
650-
.with_sort_information(vec![source_sort_keys.clone()]),
650+
.try_with_sort_information(vec![source_sort_keys.clone()])?,
651651
) as _;
652652
// Table is ordered according to ORDER BY a, b, c. In the linear test we use PARTITION BY b, ORDER BY a.
653653
// For WindowAggExec to produce a correct result, it needs the table to be ordered by b, a. Hence we add a sort.
@@ -673,7 +673,7 @@ async fn run_window_test(
673673
)?) as _;
674674
let exec2 = Arc::new(
675675
MemoryExec::try_new(&[input1.clone()], schema.clone(), None)?
676-
.with_sort_information(vec![source_sort_keys.clone()]),
676+
.try_with_sort_information(vec![source_sort_keys.clone()])?,
677677
);
678678
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
679679
vec![create_window_expr(

datafusion/core/tests/memory_limit/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ impl TableProvider for SortedTableProvider {
840840
) -> Result<Arc<dyn ExecutionPlan>> {
841841
let mem_exec =
842842
MemoryExec::try_new(&self.batches, self.schema(), projection.cloned())?
843-
.with_sort_information(self.sort_information.clone());
843+
.try_with_sort_information(self.sort_information.clone())?;
844844

845845
Ok(Arc::new(mem_exec))
846846
}

datafusion/physical-plan/src/joins/nested_loop_join.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ mod tests {
780780
};
781781
sort_info.push(sort_expr);
782782
}
783-
exec = exec.with_sort_information(vec![sort_info]);
783+
exec = exec.try_with_sort_information(vec![sort_info]).unwrap();
784784
}
785785

786786
Arc::new(exec)

datafusion/physical-plan/src/joins/test_utils.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ macro_rules! join_expr_tests {
289289
ScalarValue::$SCALAR(Some(10 as $type)),
290290
(Operator::Gt, Operator::Lt),
291291
),
292-
// left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10
292+
// left_col - 1 > right_col + 3 AND left_col + 3 < right_col + 15
293293
1 => gen_conjunctive_numerical_expr(
294294
left_col,
295295
right_col,
@@ -300,9 +300,9 @@ macro_rules! join_expr_tests {
300300
Operator::Plus,
301301
),
302302
ScalarValue::$SCALAR(Some(1 as $type)),
303-
ScalarValue::$SCALAR(Some(5 as $type)),
304303
ScalarValue::$SCALAR(Some(3 as $type)),
305-
ScalarValue::$SCALAR(Some(10 as $type)),
304+
ScalarValue::$SCALAR(Some(3 as $type)),
305+
ScalarValue::$SCALAR(Some(15 as $type)),
306306
(Operator::Gt, Operator::Lt),
307307
),
308308
// left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10
@@ -353,7 +353,8 @@ macro_rules! join_expr_tests {
353353
ScalarValue::$SCALAR(Some(3 as $type)),
354354
(Operator::Gt, Operator::Lt),
355355
),
356-
// left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3
356+
// left_col - 2 >= right_col + 5 AND left_col + 7 <= right_col - 3
357+
// (filters all input rows)
357358
5 => gen_conjunctive_numerical_expr(
358359
left_col,
359360
right_col,
@@ -369,7 +370,7 @@ macro_rules! join_expr_tests {
369370
ScalarValue::$SCALAR(Some(3 as $type)),
370371
(Operator::GtEq, Operator::LtEq),
371372
),
372-
// left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39
373+
// left_col + 28 > right_col - 11 AND left_col + 21 <= right_col + 39
373374
6 => gen_conjunctive_numerical_expr(
374375
left_col,
375376
right_col,
@@ -385,7 +386,7 @@ macro_rules! join_expr_tests {
385386
ScalarValue::$SCALAR(Some(39 as $type)),
386387
(Operator::Gt, Operator::LtEq),
387388
),
388-
// left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col + 39
389+
// left_col + 28 >= right_col - 11 AND left_col - 21 <= right_col + 39
389390
7 => gen_conjunctive_numerical_expr(
390391
left_col,
391392
right_col,
@@ -526,10 +527,10 @@ pub fn create_memory_table(
526527
) -> Result<(Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>)> {
527528
let left_schema = left_partition[0].schema();
528529
let left = MemoryExec::try_new(&[left_partition], left_schema, None)?
529-
.with_sort_information(left_sorted);
530+
.try_with_sort_information(left_sorted)?;
530531
let right_schema = right_partition[0].schema();
531532
let right = MemoryExec::try_new(&[right_partition], right_schema, None)?
532-
.with_sort_information(right_sorted);
533+
.try_with_sort_information(right_sorted)?;
533534
Ok((Arc::new(left), Arc::new(right)))
534535
}
535536

datafusion/physical-plan/src/memory.rs

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ use arrow::record_batch::RecordBatch;
3333
use datafusion_common::{internal_err, project_schema, Result};
3434
use datafusion_execution::memory_pool::MemoryReservation;
3535
use datafusion_execution::TaskContext;
36+
use datafusion_physical_expr::equivalence::ProjectionMapping;
37+
use datafusion_physical_expr::expressions::Column;
38+
use datafusion_physical_expr::utils::collect_columns;
3639
use datafusion_physical_expr::{EquivalenceProperties, LexOrdering};
3740

3841
use futures::Stream;
@@ -206,16 +209,63 @@ impl MemoryExec {
206209
/// where both `a ASC` and `b DESC` can describe the table ordering. With
207210
/// [`EquivalenceProperties`], we can keep track of these equivalences
208211
/// and treat `a ASC` and `b DESC` as the same ordering requirement.
209-
pub fn with_sort_information(mut self, sort_information: Vec<LexOrdering>) -> Self {
210-
self.sort_information = sort_information;
212+
///
213+
/// Note that if there is an internal projection, that projection will also
214+
/// be applied to the given `sort_information`.
215+
pub fn try_with_sort_information(
216+
mut self,
217+
mut sort_information: Vec<LexOrdering>,
218+
) -> Result<Self> {
219+
// All sort expressions must refer to the original schema
220+
let fields = self.schema.fields();
221+
let ambiguous_column = sort_information
222+
.iter()
223+
.flatten()
224+
.flat_map(|expr| collect_columns(&expr.expr))
225+
.find(|col| {
226+
fields
227+
.get(col.index())
228+
.map(|field| field.name() != col.name())
229+
.unwrap_or(true)
230+
});
231+
if let Some(col) = ambiguous_column {
232+
return internal_err!(
233+
"Column {:?} is not found in the original schema of the MemoryExec",
234+
col
235+
);
236+
}
237+
238+
// If there is a projection on the source, we also need to project orderings
239+
if let Some(projection) = &self.projection {
240+
let base_eqp = EquivalenceProperties::new_with_orderings(
241+
self.original_schema(),
242+
&sort_information,
243+
);
244+
let proj_exprs = projection
245+
.iter()
246+
.map(|idx| {
247+
let base_schema = self.original_schema();
248+
let name = base_schema.field(*idx).name();
249+
(Arc::new(Column::new(name, *idx)) as _, name.to_string())
250+
})
251+
.collect::<Vec<_>>();
252+
let projection_mapping =
253+
ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?;
254+
sort_information = base_eqp
255+
.project(&projection_mapping, self.schema())
256+
.oeq_class
257+
.orderings;
258+
}
211259

260+
self.sort_information = sort_information;
212261
// We need to update equivalence properties when updating sort information.
213262
let eq_properties = EquivalenceProperties::new_with_orderings(
214263
self.schema(),
215264
&self.sort_information,
216265
);
217266
self.cache = self.cache.with_eq_properties(eq_properties);
218-
self
267+
268+
Ok(self)
219269
}
220270

221271
pub fn original_schema(&self) -> SchemaRef {
@@ -347,7 +397,7 @@ mod tests {
347397

348398
let sort_information = vec![sort1.clone(), sort2.clone()];
349399
let mem_exec = MemoryExec::try_new(&[vec![]], schema, None)?
350-
.with_sort_information(sort_information);
400+
.try_with_sort_information(sort_information)?;
351401

352402
assert_eq!(
353403
mem_exec.properties().output_ordering().unwrap(),

datafusion/physical-plan/src/repartition/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1677,7 +1677,8 @@ mod test {
16771677
Arc::new(
16781678
MemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
16791679
.unwrap()
1680-
.with_sort_information(vec![sort_exprs]),
1680+
.try_with_sort_information(vec![sort_exprs])
1681+
.unwrap(),
16811682
)
16821683
}
16831684
}

datafusion/physical-plan/src/union.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,11 +809,11 @@ mod tests {
809809
.collect::<Vec<_>>();
810810
let child1 = Arc::new(
811811
MemoryExec::try_new(&[], Arc::clone(&schema), None)?
812-
.with_sort_information(first_orderings),
812+
.try_with_sort_information(first_orderings)?,
813813
);
814814
let child2 = Arc::new(
815815
MemoryExec::try_new(&[], Arc::clone(&schema), None)?
816-
.with_sort_information(second_orderings),
816+
.try_with_sort_information(second_orderings)?,
817817
);
818818

819819
let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));

0 commit comments

Comments
 (0)