Skip to content

[logical-types] use Scalar in Expr::Logical #12793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions datafusion-cli/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use datafusion::common::{plan_err, Column};
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::{Expr, Scalar};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::scalar::ScalarValue;
Expand Down Expand Up @@ -321,7 +321,10 @@ pub struct ParquetMetadataFunc {}
impl TableFunctionImpl for ParquetMetadataFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let filename = match exprs.first() {
Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet')
Some(Expr::Literal(Scalar {
value: ScalarValue::Utf8(Some(s)),
..
})) => s, // single quote: parquet_metadata('x.parquet')
Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet")
_ => {
return plan_err!(
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/expr_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async fn main() -> Result<()> {
let expr2 = Expr::BinaryExpr(BinaryExpr::new(
Box::new(col("a")),
Operator::Plus,
Box::new(Expr::Literal(ScalarValue::Int32(Some(5)))),
Box::new(Expr::from(ScalarValue::Int32(Some(5)))),
));
assert_eq!(expr, expr2);

Expand Down
14 changes: 11 additions & 3 deletions datafusion-examples/examples/simple_udtf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::{plan_err, ScalarValue};
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{Expr, TableType};
use datafusion_expr::{Expr, Scalar, TableType};
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
use std::fs::File;
use std::io::Seek;
Expand Down Expand Up @@ -133,7 +133,11 @@ struct LocalCsvTableFunc {}

impl TableFunctionImpl for LocalCsvTableFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else {
let Some(Expr::Literal(Scalar {
value: ScalarValue::Utf8(Some(ref path)),
..
})) = exprs.first()
else {
return plan_err!("read_csv requires at least one string argument");
};

Expand All @@ -145,7 +149,11 @@ impl TableFunctionImpl for LocalCsvTableFunc {
let info = SimplifyContext::new(&execution_props);
let expr = ExprSimplifier::new(info).simplify(expr.clone())?;

if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr {
if let Expr::Literal(Scalar {
value: ScalarValue::Int64(Some(limit)),
..
}) = expr
{
Ok(limit as usize)
} else {
plan_err!("Limit must be an integer")
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/benches/map_query_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ fn criterion_benchmark(c: &mut Criterion) {
let mut value_buffer = Vec::new();

for i in 0..1000 {
key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone()))));
value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i]))));
key_buffer.push(Expr::from(ScalarValue::Utf8(Some(keys[i].clone()))));
value_buffer.push(Expr::from(ScalarValue::Int32(Some(values[i]))));
}
c.bench_function("map_1000_1", |b| {
b.iter(|| {
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ impl DataFrame {
/// ```
pub async fn count(self) -> Result<usize> {
let rows = self
.aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])?
.aggregate(vec![], vec![count(Expr::from(COUNT_STAR_EXPANSION))])?
.collect()
.await?;
let len = *rows
Expand Down Expand Up @@ -2985,7 +2985,7 @@ mod tests {
let join = left.clone().join_on(
right.clone(),
JoinType::Inner,
Some(Expr::Literal(ScalarValue::Null)),
Some(Expr::from(ScalarValue::Null)),
)?;
let expected_plan = "CrossJoin:\
\n TableScan: a projection=[c1], full_filters=[Boolean(NULL)]\
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ mod tests {
assert_eq!(
evaluate_partition_prefix(
partitions,
&[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))],
&[col("a").eq(Expr::from(ScalarValue::Date32(Some(3))))],
),
Some(Path::from("a=1970-01-04")),
);
Expand All @@ -877,7 +877,7 @@ mod tests {
assert_eq!(
evaluate_partition_prefix(
partitions,
&[col("a").eq(Expr::Literal(ScalarValue::Date64(Some(
&[col("a").eq(Expr::from(ScalarValue::Date64(Some(
4 * 24 * 60 * 60 * 1000
)))),],
),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1941,7 +1941,7 @@ mod tests {
let filter_predicate = Expr::BinaryExpr(BinaryExpr::new(
Box::new(Expr::Column("column1".into())),
Operator::GtEq,
Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))),
Box::new(Expr::from(ScalarValue::Int32(Some(0)))),
));

// Create a new batch of data to insert into the table
Expand Down
18 changes: 10 additions & 8 deletions datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ impl<'schema> TreeNodeRewriter for PushdownChecker<'schema> {
//
// See comments on `FilterCandidateBuilder` for more information
let null_value = ScalarValue::try_from(field.data_type())?;
Ok(Transformed::yes(Arc::new(Literal::new(null_value)) as _))
Ok(Transformed::yes(Arc::new(Literal::from(null_value)) as _))
})
// If the column is not in the table schema, should throw the error
.map_err(|e| arrow_datafusion_err!(e));
Expand Down Expand Up @@ -699,9 +699,10 @@ mod test {
.expect("expected error free record batch");

// Test all should fail
let expr = col("timestamp_col").lt(Expr::Literal(
ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))),
));
let expr = col("timestamp_col").lt(Expr::from(ScalarValue::TimestampNanosecond(
Some(1),
Some(Arc::from("UTC")),
)));
let expr = logical2physical(&expr, &table_schema);
let candidate = FilterCandidateBuilder::new(expr, &file_schema, &table_schema)
.build(&metadata)
Expand All @@ -723,9 +724,10 @@ mod test {
assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![false; 8])));

// Test all should pass
let expr = col("timestamp_col").gt(Expr::Literal(
ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))),
));
let expr = col("timestamp_col").gt(Expr::from(ScalarValue::TimestampNanosecond(
Some(0),
Some(Arc::from("UTC")),
)));
let expr = logical2physical(&expr, &table_schema);
let candidate = FilterCandidateBuilder::new(expr, &file_schema, &table_schema)
.build(&metadata)
Expand Down Expand Up @@ -826,7 +828,7 @@ mod test {

let expr = col("str_col")
.is_not_null()
.or(col("int_col").gt(Expr::Literal(ScalarValue::UInt64(Some(5)))));
.or(col("int_col").gt(Expr::from(ScalarValue::UInt64(Some(5)))));

assert!(can_expr_be_pushed_down_with_schemas(
&expr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1237,10 +1237,10 @@ mod tests {
.run(
lit("1").eq(lit("1")).and(
col(r#""String""#)
.eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from(
.eq(Expr::from(ScalarValue::Utf8View(Some(String::from(
"Hello_Not_Exists",
)))))
.or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View(
.or(col(r#""String""#).eq(Expr::from(ScalarValue::Utf8View(
Some(String::from("Hello_Not_Exists2")),
)))),
),
Expand Down Expand Up @@ -1322,15 +1322,15 @@ mod tests {
// generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")`
.run(
col(r#""String""#)
.eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from(
.eq(Expr::from(ScalarValue::Utf8View(Some(String::from(
"Hello",
)))))
.or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View(
Some(String::from("the quick")),
))))
.or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View(
Some(String::from("are you")),
)))),
.or(col(r#""String""#).eq(Expr::from(ScalarValue::Utf8View(Some(
String::from("the quick"),
)))))
.or(col(r#""String""#).eq(Expr::from(ScalarValue::Utf8View(Some(
String::from("are you"),
))))),
)
.await
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1722,7 +1722,7 @@ pub(crate) mod tests {
let predicate = Arc::new(BinaryExpr::new(
col("c", &schema()).unwrap(),
Operator::Eq,
Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
Arc::new(Literal::from(ScalarValue::Int64(Some(0)))),
));
Arc::new(FilterExec::try_new(predicate, input).unwrap())
}
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2187,7 +2187,7 @@ mod tests {
Arc::new(Column::new("b_left_inter", 0)),
Operator::Minus,
Arc::new(BinaryExpr::new(
Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
Arc::new(Literal::from(ScalarValue::Int32(Some(1)))),
Operator::Plus,
Arc::new(Column::new("a_right_inter", 1)),
)),
Expand Down Expand Up @@ -2301,7 +2301,7 @@ mod tests {
Arc::new(Column::new("b_left_inter", 0)),
Operator::Minus,
Arc::new(BinaryExpr::new(
Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
Arc::new(Literal::from(ScalarValue::Int32(Some(1)))),
Operator::Plus,
Arc::new(Column::new("a_right_inter", 1)),
)),
Expand Down Expand Up @@ -2382,7 +2382,7 @@ mod tests {
Arc::new(Column::new("b", 7)),
Operator::Minus,
Arc::new(BinaryExpr::new(
Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
Arc::new(Literal::from(ScalarValue::Int32(Some(1)))),
Operator::Plus,
Arc::new(Column::new("a", 1)),
)),
Expand Down Expand Up @@ -2410,7 +2410,7 @@ mod tests {
Arc::new(Column::new("b_left_inter", 0)),
Operator::Minus,
Arc::new(BinaryExpr::new(
Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
Arc::new(Literal::from(ScalarValue::Int32(Some(1)))),
Operator::Plus,
Arc::new(Column::new("a_right_inter", 1)),
)),
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ impl BoolVecBuilder {
fn is_always_true(expr: &Arc<dyn PhysicalExpr>) -> bool {
expr.as_any()
.downcast_ref::<phys_expr::Literal>()
.map(|l| matches!(l.value(), ScalarValue::Boolean(Some(true))))
.map(|l| matches!(l.scalar().value(), ScalarValue::Boolean(Some(true))))
.unwrap_or_default()
}

Expand Down Expand Up @@ -1300,7 +1300,7 @@ fn build_is_null_column_expr(
Arc::new(phys_expr::BinaryExpr::new(
null_count_column_expr,
Operator::Gt,
Arc::new(phys_expr::Literal::new(ScalarValue::UInt64(Some(0)))),
Arc::new(phys_expr::Literal::from(ScalarValue::UInt64(Some(0)))),
)) as _
})
.ok()
Expand Down Expand Up @@ -1328,7 +1328,7 @@ fn build_predicate_expression(
) -> Arc<dyn PhysicalExpr> {
// Returned for unsupported expressions. Such expressions are
// converted to TRUE.
let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));
let unhandled = Arc::new(phys_expr::Literal::from(ScalarValue::Boolean(Some(true))));

// predicate expression can only be a binary expression
let expr_any = expr.as_any();
Expand Down Expand Up @@ -1549,7 +1549,7 @@ fn wrap_case_expr(
Operator::Eq,
expr_builder.row_count_column_expr()?,
));
let then = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false))));
let then = Arc::new(phys_expr::Literal::from(ScalarValue::Boolean(Some(false))));

// CASE WHEN x_null_count = x_row_count THEN false ELSE <statistics_expr> END
Ok(Arc::new(phys_expr::CaseExpr::try_new(
Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,7 @@ fn get_null_physical_expr_pair(
let data_type = physical_expr.data_type(input_schema)?;
let null_value: ScalarValue = (&data_type).try_into()?;

let null_value = Literal::new(null_value);
let null_value = Literal::from(null_value);
Ok((Arc::new(null_value), physical_name))
}

Expand Down Expand Up @@ -2018,7 +2018,7 @@ mod tests {
// verify that the plan correctly casts u8 to i64
// the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
// the cast here is implicit so has CastOptions with safe=true
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }";
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { scalar: Int64(5) }, fail_on_overflow: false }";
assert!(format!("{exec_plan:?}").contains(expected));
Ok(())
}
Expand All @@ -2043,7 +2043,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { scalar: Utf8(NULL) }, "c1"), (Literal { scalar: Int64(NULL) }, "c2"), (Literal { scalar: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;

assert_eq!(format!("{cube:?}"), expected);

Expand All @@ -2070,7 +2070,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { scalar: Utf8(NULL) }, "c1"), (Literal { scalar: Int64(NULL) }, "c2"), (Literal { scalar: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;

assert_eq!(format!("{rollup:?}"), expected);

Expand Down Expand Up @@ -2254,7 +2254,7 @@ mod tests {
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.

let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }";
let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { scalar: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { scalar: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }";

let actual = format!("{execution_plan:?}");
assert!(actual.contains(expected), "{}", actual);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,17 @@ impl TableProvider for CustomProvider {
match &filters[0] {
Expr::BinaryExpr(BinaryExpr { right, .. }) => {
let int_value = match &**right {
Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int64(Some(i))) => *i,
Expr::Literal(lit_value) => match lit_value.value() {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
ScalarValue::Int32(Some(v)) => *v as i64,
ScalarValue::Int64(Some(v)) => *v,
other_value => {
return not_impl_err!("Do not support value {other_value:?}");
}
},
Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() {
Expr::Literal(lit_value) => match lit_value {
Expr::Literal(lit_value) => match lit_value.value() {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
ScalarValue::Int32(Some(v)) => *v as i64,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1634,7 +1634,7 @@ async fn consecutive_projection_same_schema() -> Result<()> {

// Add `t` column full of nulls
let df = df
.with_column("t", cast(Expr::Literal(ScalarValue::Null), DataType::Int32))
.with_column("t", cast(Expr::from(ScalarValue::Null), DataType::Int32))
.unwrap();
df.clone().show().await.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/expr_api/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ fn select_date_plus_interval() -> Result<()> {

let date_plus_interval_expr = to_timestamp_expr(ts_string)
.cast_to(&DataType::Date32, schema)?
+ Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime {
+ Expr::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime {
days: 123,
milliseconds: 0,
})));
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ impl JoinFuzzTestCase {
filter.schema().fields().len(),
)
} else {
(Arc::new(Literal::new(ScalarValue::from(true))) as _, 0)
(Arc::new(Literal::from(ScalarValue::from(true))) as _, 0)
};

let equal_a = Arc::new(BinaryExpr::new(
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/path_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async fn parquet_partition_pruning_filter() -> Result<()> {
let expected = Arc::new(BinaryExpr::new(
Arc::new(Column::new_with_schema("id", &exec.schema()).unwrap()),
Operator::Gt,
Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
Arc::new(Literal::from(ScalarValue::Int32(Some(1)))),
));

assert!(pred.as_any().is::<BinaryExpr>());
Expand Down
Loading