diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index c622463de033..3b91abf8f3dc 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -27,7 +27,7 @@ use datafusion::common::{plan_err, Column}; use datafusion::datasource::function::TableFunctionImpl; use datafusion::datasource::TableProvider; use datafusion::error::Result; -use datafusion::logical_expr::Expr; +use datafusion::logical_expr::{Expr, Scalar}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::scalar::ScalarValue; @@ -321,7 +321,10 @@ pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { fn call(&self, exprs: &[Expr]) -> Result> { let filename = match exprs.first() { - Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Literal(Scalar { + value: ScalarValue::Utf8(Some(s)), + .. + })) => s, // single quote: parquet_metadata('x.parquet') Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") _ => { return plan_err!( diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 0eb823302acf..85a79a1a5604 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -61,7 +61,7 @@ async fn main() -> Result<()> { let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, - Box::new(Expr::Literal(ScalarValue::Int32(Some(5)))), + Box::new(Expr::from(ScalarValue::Int32(Some(5)))), )); assert_eq!(expr, expr2); diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index 6faa397ef60f..8cc4309d4d31 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -30,7 +30,7 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{Expr, TableType}; +use datafusion_expr::{Expr, Scalar, TableType}; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use std::fs::File; use std::io::Seek; @@ -133,7 +133,11 @@ struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + let Some(Expr::Literal(Scalar { + value: ScalarValue::Utf8(Some(ref path)), + .. + })) = exprs.first() + else { return plan_err!("read_csv requires at least one string argument"); }; @@ -145,7 +149,11 @@ impl TableFunctionImpl for LocalCsvTableFunc { let info = SimplifyContext::new(&execution_props); let expr = ExprSimplifier::new(info).simplify(expr.clone())?; - if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + if let Expr::Literal(Scalar { + value: ScalarValue::Int64(Some(limit)), + .. 
+ }) = expr + { Ok(limit as usize) } else { plan_err!("Limit must be an integer") diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs index e4c5f7c5deb3..59e0873b7174 100644 --- a/datafusion/core/benches/map_query_sql.rs +++ b/datafusion/core/benches/map_query_sql.rs @@ -71,8 +71,8 @@ fn criterion_benchmark(c: &mut Criterion) { let mut value_buffer = Vec::new(); for i in 0..1000 { - key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + key_buffer.push(Expr::from(ScalarValue::Utf8(Some(keys[i].clone())))); + value_buffer.push(Expr::from(ScalarValue::Int32(Some(values[i])))); } c.bench_function("map_1000_1", |b| { b.iter(|| { diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 67e2a4780d06..72cffd2e7c39 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1188,7 +1188,7 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? + .aggregate(vec![], vec![count(Expr::from(COUNT_STAR_EXPANSION))])? .collect() .await?; let len = *rows @@ -2985,7 +2985,7 @@ mod tests { let join = left.clone().join_on( right.clone(), JoinType::Inner, - Some(Expr::Literal(ScalarValue::Null)), + Some(Expr::from(ScalarValue::Null)), )?; let expected_plan = "CrossJoin:\ \n TableScan: a projection=[c1], full_filters=[Boolean(NULL)]\ diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 72d7277d6ae2..e18fb8fc7ba3 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -868,7 +868,7 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))], + &[col("a").eq(Expr::from(ScalarValue::Date32(Some(3))))], ), Some(Path::from("a=1970-01-04")), ); @@ -877,7 +877,7 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date64(Some( + &[col("a").eq(Expr::from(ScalarValue::Date64(Some( 4 * 24 * 60 * 60 * 1000 )))),], ), diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index a9c6aec17537..0da146c558a0 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -1941,7 +1941,7 @@ mod tests { let filter_predicate = Expr::BinaryExpr(BinaryExpr::new( Box::new(Expr::Column("column1".into())), Operator::GtEq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))), + Box::new(Expr::from(ScalarValue::Int32(Some(0)))), )); // Create a new batch of data to insert into the table diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index e876f840d1eb..b161d55f2e6e 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -373,7 +373,7 @@ impl<'schema> TreeNodeRewriter for PushdownChecker<'schema> { // // See comments on `FilterCandidateBuilder` for more information let null_value = ScalarValue::try_from(field.data_type())?; - Ok(Transformed::yes(Arc::new(Literal::new(null_value)) as _)) + Ok(Transformed::yes(Arc::new(Literal::from(null_value)) as _)) }) // If the column 
is not in the table schema, should throw the error .map_err(|e| arrow_datafusion_err!(e)); @@ -699,9 +699,10 @@ mod test { .expect("expected error free record batch"); // Test all should fail - let expr = col("timestamp_col").lt(Expr::Literal( - ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))), - )); + let expr = col("timestamp_col").lt(Expr::from(ScalarValue::TimestampNanosecond( + Some(1), + Some(Arc::from("UTC")), + ))); let expr = logical2physical(&expr, &table_schema); let candidate = FilterCandidateBuilder::new(expr, &file_schema, &table_schema) .build(&metadata) @@ -723,9 +724,10 @@ mod test { assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![false; 8]))); // Test all should pass - let expr = col("timestamp_col").gt(Expr::Literal( - ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))), - )); + let expr = col("timestamp_col").gt(Expr::from(ScalarValue::TimestampNanosecond( + Some(0), + Some(Arc::from("UTC")), + ))); let expr = logical2physical(&expr, &table_schema); let candidate = FilterCandidateBuilder::new(expr, &file_schema, &table_schema) .build(&metadata) @@ -826,7 +828,7 @@ mod test { let expr = col("str_col") .is_not_null() - .or(col("int_col").gt(Expr::Literal(ScalarValue::UInt64(Some(5))))); + .or(col("int_col").gt(Expr::from(ScalarValue::UInt64(Some(5))))); assert!(can_expr_be_pushed_down_with_schemas( &expr, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs index a1d74cb54355..1e5a0195449d 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs @@ -1237,10 +1237,10 @@ mod tests { .run( lit("1").eq(lit("1")).and( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( + .eq(Expr::from(ScalarValue::Utf8View(Some(String::from( "Hello_Not_Exists", ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( + .or(col(r#""String""#).eq(Expr::from(ScalarValue::Utf8View( Some(String::from("Hello_Not_Exists2")), )))), ), @@ -1322,15 +1322,15 @@ mod tests { // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` .run( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( + .eq(Expr::from(ScalarValue::Utf8View(Some(String::from( "Hello", ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("the quick")), - )))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("are you")), - )))), + .or(col(r#""String""#).eq(Expr::from(ScalarValue::Utf8View(Some( + String::from("the quick"), + ))))) + .or(col(r#""String""#).eq(Expr::from(ScalarValue::Utf8View(Some( + String::from("are you"), + ))))), ) .await } diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index c971e6150633..b64d1d5a83f8 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1722,7 +1722,7 @@ pub(crate) mod tests { let predicate = Arc::new(BinaryExpr::new( col("c", &schema()).unwrap(), Operator::Eq, - Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + Arc::new(Literal::from(ScalarValue::Int64(Some(0)))), )); Arc::new(FilterExec::try_new(predicate, input).unwrap()) } diff --git 
a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index b4dd0a995d5f..7ac70d701cf6 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -2187,7 +2187,7 @@ mod tests { Arc::new(Column::new("b_left_inter", 0)), Operator::Minus, Arc::new(BinaryExpr::new( - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), Operator::Plus, Arc::new(Column::new("a_right_inter", 1)), )), @@ -2301,7 +2301,7 @@ mod tests { Arc::new(Column::new("b_left_inter", 0)), Operator::Minus, Arc::new(BinaryExpr::new( - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), Operator::Plus, Arc::new(Column::new("a_right_inter", 1)), )), @@ -2382,7 +2382,7 @@ mod tests { Arc::new(Column::new("b", 7)), Operator::Minus, Arc::new(BinaryExpr::new( - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), Operator::Plus, Arc::new(Column::new("a", 1)), )), @@ -2410,7 +2410,7 @@ mod tests { Arc::new(Column::new("b_left_inter", 0)), Operator::Minus, Arc::new(BinaryExpr::new( - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), Operator::Plus, Arc::new(Column::new("a_right_inter", 1)), )), diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 3c8e5ddd1c74..e88b25cc2363 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -714,7 +714,7 @@ impl BoolVecBuilder { fn is_always_true(expr: &Arc) -> bool { expr.as_any() .downcast_ref::() - .map(|l| matches!(l.value(), ScalarValue::Boolean(Some(true)))) + .map(|l| matches!(l.scalar().value(), ScalarValue::Boolean(Some(true)))) .unwrap_or_default() } @@ -1300,7 +1300,7 @@ fn build_is_null_column_expr( Arc::new(phys_expr::BinaryExpr::new( null_count_column_expr, Operator::Gt, - Arc::new(phys_expr::Literal::new(ScalarValue::UInt64(Some(0)))), + Arc::new(phys_expr::Literal::from(ScalarValue::UInt64(Some(0)))), )) as _ }) .ok() @@ -1328,7 +1328,7 @@ fn build_predicate_expression( ) -> Arc { // Returned for unsupported expressions. Such expressions are // converted to TRUE. 
- let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))); + let unhandled = Arc::new(phys_expr::Literal::from(ScalarValue::Boolean(Some(true)))); // predicate expression can only be a binary expression let expr_any = expr.as_any(); @@ -1549,7 +1549,7 @@ fn wrap_case_expr( Operator::Eq, expr_builder.row_count_column_expr()?, )); - let then = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false)))); + let then = Arc::new(phys_expr::Literal::from(ScalarValue::Boolean(Some(false)))); // CASE WHEN x_null_count = x_row_count THEN false ELSE END Ok(Arc::new(phys_expr::CaseExpr::try_new( diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index cf2a157b04b6..dac0634316f3 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1423,7 +1423,7 @@ fn get_null_physical_expr_pair( let data_type = physical_expr.data_type(input_schema)?; let null_value: ScalarValue = (&data_type).try_into()?; - let null_value = Literal::new(null_value); + let null_value = Literal::from(null_value); Ok((Arc::new(null_value), physical_name)) } @@ -2018,7 +2018,7 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5)) // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }"; + let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { scalar: Int64(5) }, fail_on_overflow: false }"; assert!(format!("{exec_plan:?}").contains(expected)); Ok(()) } @@ -2043,7 +2043,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { scalar: Utf8(NULL) }, "c1"), (Literal { scalar: Int64(NULL) }, "c2"), (Literal { scalar: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; assert_eq!(format!("{cube:?}"), expected); @@ -2070,7 +2070,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { scalar: Utf8(NULL) }, "c1"), (Literal { scalar: Int64(NULL) }, "c2"), (Literal { 
scalar: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; assert_eq!(format!("{rollup:?}"), expected); @@ -2254,7 +2254,7 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. - let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }"; + let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { scalar: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { scalar: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }"; let actual = format!("{execution_plan:?}"); assert!(actual.contains(expected), "{}", actual); diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 09f7265d639a..ca390ada2b69 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -174,12 +174,17 @@ impl TableProvider for CustomProvider { match &filters[0] { Expr::BinaryExpr(BinaryExpr { right, .. }) => { let int_value = match &**right { - Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int64(Some(i))) => *i, + Expr::Literal(lit_value) => match lit_value.value() { + ScalarValue::Int8(Some(v)) => *v as i64, + ScalarValue::Int16(Some(v)) => *v as i64, + ScalarValue::Int32(Some(v)) => *v as i64, + ScalarValue::Int64(Some(v)) => *v, + other_value => { + return not_impl_err!("Do not support value {other_value:?}"); + } + }, Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { - Expr::Literal(lit_value) => match lit_value { + Expr::Literal(lit_value) => match lit_value.value() { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, ScalarValue::Int32(Some(v)) => *v as i64, diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 3520ab8fed2b..41219c7116ce 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1634,7 +1634,7 @@ async fn consecutive_projection_same_schema() -> Result<()> { // Add `t` column full of nulls let df = df - .with_column("t", cast(Expr::Literal(ScalarValue::Null), DataType::Int32)) + .with_column("t", cast(Expr::from(ScalarValue::Null), DataType::Int32)) .unwrap(); df.clone().show().await.unwrap(); diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index d7995d4663be..a279f516b996 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -282,7 +282,7 @@ fn select_date_plus_interval() -> Result<()> { let date_plus_interval_expr = to_timestamp_expr(ts_string) .cast_to(&DataType::Date32, schema)? 
- + Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + + Expr::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 123, milliseconds: 0, }))); diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 96aa1be181f5..2535b94eaec4 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -329,7 +329,7 @@ impl JoinFuzzTestCase { filter.schema().fields().len(), ) } else { - (Arc::new(Literal::new(ScalarValue::from(true))) as _, 0) + (Arc::new(Literal::from(ScalarValue::from(true))) as _, 0) }; let equal_a = Arc::new(BinaryExpr::new( diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index 919054e8330f..0c983bd732d0 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -91,7 +91,7 @@ async fn parquet_partition_pruning_filter() -> Result<()> { let expected = Arc::new(BinaryExpr::new( Arc::new(Column::new_with_schema("id", &exec.schema()).unwrap()), Operator::Gt, - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), )); assert!(pred.as_any().is::()); diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs index ad9c1280d6b1..038d745a7fdb 100644 --- a/datafusion/core/tests/user_defined/expr_planner.rs +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -55,7 +55,7 @@ impl ExprPlanner for MyCustomPlanner { } BinaryOperator::Question => { Ok(PlannerResult::Planned(Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), + Expr::from(ScalarValue::Boolean(Some(true))), None::<&str>, format!("{} ? {}", expr.left, expr.right), )))) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 2b45d0ed600b..d9965c20890a 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -98,7 +98,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; use datafusion_expr::tree_node::replace_sort_expression; -use datafusion_expr::{Projection, SortExpr}; +use datafusion_expr::{Projection, Scalar, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; @@ -728,9 +728,12 @@ impl MyAnalyzerRule { .map(|e| { e.transform(|e| { Ok(match e { - Expr::Literal(ScalarValue::Int64(i)) => { + Expr::Literal(Scalar { + value: ScalarValue::Int64(i), + .. 
+ }) => { // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( + Transformed::yes(Expr::from(ScalarValue::UInt64( i.map(|i| i as u64), ))) } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 2f47d78db2d5..a6e7847d463b 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -42,7 +42,7 @@ use datafusion_common::{ use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, ExprSchemable, - LogicalPlanBuilder, OperateFunctionArg, ScalarUDF, ScalarUDFImpl, Signature, + LogicalPlanBuilder, OperateFunctionArg, Scalar, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; @@ -648,8 +648,10 @@ impl ScalarUDFImpl for TakeUDF { return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len()); } - let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) = - arg_exprs.get(2) + let take_idx = if let Some(Expr::Literal(Scalar { + value: ScalarValue::Int64(Some(idx)), + .. + })) = arg_exprs.get(2) { if *idx == 0 || *idx == 1 { *idx as usize diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 0cc156866d4d..bf5c822be240 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -30,7 +30,7 @@ use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::SessionContext; use datafusion_catalog::Session; use datafusion_common::{assert_batches_eq, DFSchema, ScalarValue}; -use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType}; +use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, Scalar, TableType}; use std::fs::File; use std::io::Seek; use std::path::Path; @@ -201,7 +201,10 @@ impl TableFunctionImpl for SimpleCsvTableFunc { let mut filepath = String::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + Expr::Literal(Scalar { + value: ScalarValue::Utf8(Some(ref path)), + .. + }) => { filepath.clone_from(path); } expr => new_exprs.push(expr.clone()), diff --git a/datafusion/expr-common/src/scalar.rs b/datafusion/expr-common/src/scalar.rs index d4bb04e1f94c..a1e2cc40bb27 100644 --- a/datafusion/expr-common/src/scalar.rs +++ b/datafusion/expr-common/src/scalar.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +use core::fmt; +use std::hash::Hash; + use arrow::{ array::{Array, ArrayRef}, datatypes::DataType, @@ -40,9 +43,9 @@ use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; /// - [`Scalar::new_null_of`] /// /// [`Scalar`] is meant to be stored as a variant of [`crate::columnar_value::ColumnarValue`]. 
-#[derive(Clone, Debug)] +#[derive(Clone, Eq)] pub struct Scalar { - value: ScalarValue, + pub value: ScalarValue, data_type: DataType, } @@ -76,6 +79,30 @@ impl PartialEq for Scalar { } } +impl PartialOrd for Scalar { + fn partial_cmp(&self, other: &Self) -> Option { + self.value.partial_cmp(&other.value) + } +} + +impl Hash for Scalar { + fn hash(&self, state: &mut H) { + self.value.hash(state); + } +} + +impl fmt::Display for Scalar { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.value.fmt(f) + } +} + +impl fmt::Debug for Scalar { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.value.fmt(f) + } +} + impl Scalar { /// Converts a value in `array` at `index` into a Scalar, resolving the /// [`DataType`] from the `array`'s data type. @@ -104,7 +131,7 @@ impl Scalar { /// Converts an iterator of references [`Scalar`] into an [`ArrayRef`] /// corresponding to those values. - pub fn iter_to_array(scalars: impl IntoIterator) -> Result { + pub fn iter_to_array(scalars: impl IntoIterator) -> Result { let mut scalars = scalars.into_iter().peekable(); // figure out the type based on the first element @@ -116,6 +143,10 @@ impl Scalar { ScalarValue::iter_to_array_of_type(scalars.map(|scalar| scalar.value), &data_type) } + pub fn cast_to(&self, data_type: &DataType) -> Result { + Ok(Self::from(self.value.cast_to(data_type)?)) + } + #[inline] pub fn value(&self) -> &ScalarValue { &self.value diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 215efb59526f..93c7703ed956 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -40,6 +40,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{ plan_err, Column, DFSchema, Result, ScalarValue, TableReference, }; +use datafusion_expr_common::scalar::Scalar; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use sqlparser::ast::{ display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, @@ -54,7 +55,7 @@ use sqlparser::ast::{ /// BinaryExpr { /// left: Expr::Column("A"), /// op: Operator::Plus, -/// right: Expr::Literal(ScalarValue::Int32(Some(1))) +/// right: Expr::from(ScalarValue::Int32(Some(1))) /// } /// ``` /// @@ -107,9 +108,9 @@ use sqlparser::ast::{ /// # use datafusion_expr::{lit, col, Expr}; /// // All literals are strongly typed in DataFusion. 
To make an `i64` 42: /// let expr = lit(42i64); -/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); +/// assert_eq!(expr, Expr::from(ScalarValue::Int64(Some(42)))); /// // To make a (typed) NULL: -/// let expr = Expr::Literal(ScalarValue::Int64(None)); +/// let expr = Expr::from(ScalarValue::Int64(None)); /// // to make an (untyped) NULL (the optimizer will coerce this to the correct type): /// let expr = lit(ScalarValue::Null); /// ``` @@ -143,7 +144,7 @@ use sqlparser::ast::{ /// if let Expr::BinaryExpr(binary_expr) = expr { /// assert_eq!(*binary_expr.left, col("c1")); /// let scalar = ScalarValue::Int32(Some(42)); -/// assert_eq!(*binary_expr.right, Expr::Literal(scalar)); +/// assert_eq!(*binary_expr.right, Expr::from(scalar)); /// assert_eq!(binary_expr.op, Operator::Eq); /// } /// ``` @@ -190,7 +191,7 @@ use sqlparser::ast::{ /// // apply recursively visits all nodes in the expression tree /// expr.apply(|e| { /// if let Expr::Literal(scalar) = e { -/// scalars.insert(scalar); +/// scalars.insert(scalar.value()); /// } /// // The return value controls whether to continue visiting the tree /// Ok(TreeNodeRecursion::Continue) @@ -233,7 +234,7 @@ pub enum Expr { /// A named reference to a variable in a registry. ScalarVariable(DataType, Vec), /// A constant value. - Literal(ScalarValue), + Literal(Scalar), /// A binary expression such as "age > 21" BinaryExpr(BinaryExpr), /// LIKE expression @@ -333,7 +334,7 @@ pub enum Expr { impl Default for Expr { fn default() -> Self { - Expr::Literal(ScalarValue::Null) + Expr::from(ScalarValue::Null) } } @@ -2409,7 +2410,7 @@ mod test { #[allow(deprecated)] fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), + expr: Box::new(Expr::from(ScalarValue::Float32(Some(1.23)))), data_type: DataType::Utf8, }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 2975e36488dc..4b30def27630 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -673,17 +673,17 @@ impl WindowUDFImpl for SimpleWindowUDF { pub fn interval_year_month_lit(value: &str) -> Expr { let interval = parse_interval_year_month(value).ok(); - Expr::Literal(ScalarValue::IntervalYearMonth(interval)) + Expr::from(ScalarValue::IntervalYearMonth(interval)) } pub fn interval_datetime_lit(value: &str) -> Expr { let interval = parse_interval_day_time(value).ok(); - Expr::Literal(ScalarValue::IntervalDayTime(interval)) + Expr::from(ScalarValue::IntervalDayTime(interval)) } pub fn interval_month_day_nano_lit(value: &str) -> Expr { let interval = parse_interval_month_day_nano(value).ok(); - Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) + Expr::from(ScalarValue::IntervalMonthDayNano(interval)) } /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 15930914dd59..a304f43e6bee 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -380,14 +380,13 @@ mod test { // rewrites all "foo" string literals to "bar" let transformer = |expr: Expr| -> Result> { match expr { - Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { - let utf8_val = if utf8_val == "foo" { - "bar".to_string() - } else { - utf8_val - }; - Ok(Transformed::yes(lit(utf8_val))) - } + Expr::Literal(scalar) => match scalar.value() { + 
ScalarValue::Utf8(Some(utf8_val)) => { + let utf8_val = if utf8_val == "foo" { "bar" } else { utf8_val }; + Ok(Transformed::yes(lit(utf8_val))) + } + _ => Ok(Transformed::no(Expr::Literal(scalar))), + }, // otherwise, return None _ => Ok(Transformed::no(expr)), } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index ad617c53d617..03c63e9b2de7 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -112,7 +112,7 @@ impl ExprSchemable for Expr { Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), - Expr::Literal(l) => Ok(l.data_type()), + Expr::Literal(l) => Ok(l.data_type().clone()), Expr::Case(case) => { for (_, then_expr) in &case.when_then_expr { let then_type = then_expr.get_type(schema)?; @@ -277,7 +277,7 @@ impl ExprSchemable for Expr { Expr::Column(c) => input_schema.nullable(c), Expr::OuterReferenceColumn(_, _) => Ok(true), - Expr::Literal(value) => Ok(value.is_null()), + Expr::Literal(value) => Ok(value.value().is_null()), Expr::Case(case) => { // this expression is nullable if any of the input expressions are nullable let then_nullable = case @@ -381,7 +381,7 @@ impl ExprSchemable for Expr { .map(|(d, n)| (d.clone(), n)), Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), - Expr::Literal(l) => Ok((l.data_type(), l.is_null())), + Expr::Literal(l) => Ok((l.data_type().clone(), l.value().is_null())), Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index 90ba5a9a693c..2904cf13a1f0 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -19,6 +19,7 @@ use crate::Expr; use datafusion_common::ScalarValue; +use datafusion_expr_common::scalar::Scalar; /// Create a literal expression pub fn lit(n: T) -> Expr { @@ -41,39 +42,45 @@ pub trait TimestampLiteral { fn lit_timestamp_nano(&self) -> Expr; } +impl From for Expr { + fn from(value: ScalarValue) -> Self { + Self::Literal(Scalar::from(value)) + } +} + impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(*self)) + Expr::from(ScalarValue::from(*self)) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::from(ScalarValue::from(self.as_ref())) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::from(ScalarValue::from(self.as_ref())) } } impl Literal for Vec { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::from(ScalarValue::Binary(Some((*self).to_owned()))) } } impl Literal for &[u8] { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::from(ScalarValue::Binary(Some((*self).to_owned()))) } } impl Literal for ScalarValue { fn lit(&self) -> Expr { - Expr::Literal(self.clone()) + Expr::from(self.clone()) } } @@ -82,7 +89,7 @@ macro_rules! make_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) + Expr::from(ScalarValue::$SCALAR(Some(self.clone()))) } } }; @@ -93,7 +100,7 @@ macro_rules! 
make_nonzero_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.get()))) + Expr::from(ScalarValue::$SCALAR(Some(self.get()))) } } }; @@ -104,7 +111,7 @@ macro_rules! make_timestamp_literal { #[doc = $DOC] impl TimestampLiteral for $TYPE { fn lit_timestamp_nano(&self) -> Expr { - Expr::Literal(ScalarValue::TimestampNanosecond( + Expr::from(ScalarValue::TimestampNanosecond( Some((self.clone()).into()), None, )) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index da2a96327ce5..fea34c5c60d3 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -225,12 +225,14 @@ impl LogicalPlanBuilder { // wrap cast if data type is not same as common type. for row in &mut values { for (j, field_type) in field_types.iter().enumerate() { - if let Expr::Literal(ScalarValue::Null) = row[j] { - row[j] = Expr::Literal(ScalarValue::try_from(field_type)?); - } else { - row[j] = - std::mem::take(&mut row[j]).cast_to(field_type, &empty_schema)?; - } + row[j] = match &row[j] { + Expr::Literal(scalar) if scalar.value().is_null() => { + Expr::from(ScalarValue::try_from(field_type)?) + } + _ => { + std::mem::take(&mut row[j]).cast_to(field_type, &empty_schema)? + } + }; } } let fields = field_types diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 0292274e57ee..6cd130ba7703 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1448,7 +1448,7 @@ impl LogicalPlan { e.infer_placeholder_types(&schema)?.transform_up(|e| { if let Expr::Placeholder(Placeholder { id, .. }) = e { let value = param_values.get_placeholders_with_values(&id)?; - Ok(Transformed::yes(Expr::Literal(value))) + Ok(Transformed::yes(Expr::from(value))) } else { Ok(Transformed::no(e)) } @@ -4024,7 +4024,7 @@ digraph { let col = schema.field_names()[0].clone(); let filter = Filter::try_new( - Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + Expr::Column(col.into()).eq(Expr::from(ScalarValue::Int32(Some(1)))), scan, ) .unwrap(); diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index cc245b3572ec..282fa2b95be5 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -316,7 +316,7 @@ impl AggregateUDFImpl for Count { .as_any() .downcast_ref::() { - if lit_expr.value() == &COUNT_STAR_EXPANSION { + if lit_expr.scalar().value() == &COUNT_STAR_EXPANSION { return Some(ScalarValue::Int64(Some(num_rows as i64))); } } diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index bbfe56914c91..399303e7e77b 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -102,7 +102,7 @@ impl AggregateUDFImpl for NthValueAgg { let n = match acc_args.exprs[1] .as_any() .downcast_ref::() - .map(|lit| lit.value()) + .map(|lit| lit.scalar().value()) { Some(ScalarValue::Int64(Some(value))) => { if acc_args.is_reversed { diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index a7e9a37e23ad..1b3654547ef6 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -84,7 +84,7 @@ impl AggregateUDFImpl for StringAgg { fn accumulator(&self, 
acc_args: AccumulatorArgs) -> Result> { if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() { - return match lit.value() { + return match lit.scalar().value() { ScalarValue::Utf8(Some(delimiter)) | ScalarValue::LargeUtf8(Some(delimiter)) => { Ok(Box::new(StringAggAccumulator::new(delimiter.as_str()))) diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 24e892f8b715..e97859100058 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -58,8 +58,8 @@ fn criterion_benchmark(c: &mut Criterion) { let values = values(&mut rng); let mut buffer = Vec::new(); for i in 0..1000 { - buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + buffer.push(Expr::from(ScalarValue::Utf8(Some(keys[i].clone())))); + buffer.push(Expr::from(ScalarValue::Int32(Some(values[i])))); } let planner = NestedFunctionPlanner {}; diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index a1b74228a503..0b6f506d447f 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -138,7 +138,15 @@ fn data_type_from_args(args: &[Expr]) -> Result { if args.len() != 2 { return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } - let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { + + let Expr::Literal(scalar) = &args[1] else { + return plan_err!( + "arrow_cast requires its second argument to be a scalar, got {:?}", + &args[1] + ); + }; + + let ScalarValue::Utf8(Some(val)) = scalar.value() else { return plan_err!( "arrow_cast requires its second argument to be a constant string, got {:?}", &args[1] diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 6f863809573b..0ff6067dd37e 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -127,7 +127,7 @@ impl ScalarUDFImpl for GetFieldFunc { } }; let data_type = args[0].get_type(schema)?; - match (data_type, name) { + match (data_type, name.value()) { (DataType::Map(fields, _), _) => { match fields.data_type() { DataType::Struct(fields) if fields.len() == 2 => { diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 9c3a3664ca32..b367f20ca126 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -151,7 +151,11 @@ impl ScalarUDFImpl for NamedStructFunc { let name = &chunk[0]; let value = &chunk[1]; - if let Expr::Literal(ScalarValue::Utf8(Some(name))) = name { + let Expr::Literal(scalar) = name else { + return exec_err!("named_struct even arguments must be string literals, got {name} instead at position {}", i * 2) + }; + + if let ScalarValue::Utf8(Some(name)) = scalar.value() { Ok(Field::new(name, value.get_type(schema)?, true)) } else { exec_err!("named_struct even arguments must be string literals, got {name} instead at position {}", i * 2) diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 8b180ff41b91..10bc9dabf90f 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -91,7 +91,7 @@ impl ScalarUDFImpl for CurrentDateFunc { .unwrap() .num_days_from_ce(), ); - Ok(ExprSimplifyResult::Simplified(Expr::Literal( + 
Ok(ExprSimplifyResult::Simplified(Expr::from( ScalarValue::Date32(days), ))) } diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index 803759d4e904..6872a13eebdb 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -80,7 +80,7 @@ impl ScalarUDFImpl for CurrentTimeFunc { ) -> Result { let now_ts = info.execution_props().query_execution_start_time; let nano = now_ts.timestamp_nanos_opt().map(|ts| ts % 86400000000000); - Ok(ExprSimplifyResult::Simplified(Expr::Literal( + Ok(ExprSimplifyResult::Simplified(Expr::from( ScalarValue::Time64Nanosecond(nano), ))) } diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index b2221215b94b..80d2e4966f81 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -80,7 +80,7 @@ impl ScalarUDFImpl for NowFunc { .execution_props() .query_execution_start_time .timestamp_nanos_opt(); - Ok(ExprSimplifyResult::Simplified(Expr::Literal( + Ok(ExprSimplifyResult::Simplified(Expr::from( ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), ))) } diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 1ffda8759d7b..d79feb79ae35 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -219,7 +219,9 @@ impl ScalarUDFImpl for LogFunc { }; match number { - Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => { + Expr::Literal(scalar) + if scalar.value() == &ScalarValue::new_one(&number_datatype)? => + { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( &info.get_data_type(&base)?, )?))) diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 831f983d5916..44d6a20d96b1 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -147,12 +147,16 @@ impl ScalarUDFImpl for PowerFunc { let exponent_type = info.get_data_type(&exponent)?; match exponent { - Expr::Literal(value) if value == ScalarValue::new_zero(&exponent_type)? => { - Ok(ExprSimplifyResult::Simplified(Expr::Literal( + Expr::Literal(scalar) + if scalar.value() == &ScalarValue::new_zero(&exponent_type)? => + { + Ok(ExprSimplifyResult::Simplified(Expr::from( ScalarValue::new_one(&info.get_data_type(&base)?)?, ))) } - Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => { + Expr::Literal(scalar) + if scalar.value() == &ScalarValue::new_one(&exponent_type)? => + { Ok(ExprSimplifyResult::Simplified(base)) } Expr::ScalarFunction(ScalarFunction { func, mut args }) diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index e854ff375503..7a2426606b7c 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -286,18 +286,20 @@ pub fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { + Expr::Literal(scalar) => match scalar.value() { + // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} + ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None) => {} // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. 
- Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), - ) => contiguous_scalar += &v, - Expr::Literal(x) => { + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)) + => contiguous_scalar += v, + x => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." ) } + } // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` (if it is not empty) and reset it to empty string. // Then pushing this arg to the `new_args`. diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 761dddd1047b..26e77ff593af 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -323,24 +323,25 @@ fn get_concat_ws_doc() -> &'static Documentation { fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { - Expr::Literal( + Expr::Literal(scalar) => match scalar.value() { ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter) - | ScalarValue::Utf8View(delimiter), - ) => { - match delimiter { - // when the delimiter is an empty string, - // we can use `concat` to replace `concat_ws` - Some(delimiter) if delimiter.is_empty() => simplify_concat(args.to_vec()), - Some(delimiter) => { - let mut new_args = Vec::with_capacity(args.len()); - new_args.push(lit(delimiter)); - let mut contiguous_scalar = None; - for arg in args { - match arg { - // filter out null args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => { + | ScalarValue::Utf8View(delimiter) => { + match delimiter { + // when the delimiter is an empty string, + // we can use `concat` to replace `concat_ws` + Some(delimiter) if delimiter.is_empty() => { + simplify_concat(args.to_vec()) + } + Some(delimiter) => { + let mut new_args = Vec::with_capacity(args.len()); + new_args.push(lit(delimiter)); + let mut contiguous_scalar = None; + for arg in args { + match arg { + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None) => {} + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)) => { match contiguous_scalar { None => contiguous_scalar = Some(v.to_string()), Some(mut pre) => { @@ -350,7 +351,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result return internal_err!("The scalar {s} should be casted to string type during the type coercion."), + s => return internal_err!("The scalar {s} should be casted to string type during the type coercion."), + } // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` and reset it to None. // Then pushing this arg to the `new_args`. @@ -362,27 +364,28 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Ok(ExprSimplifyResult::Simplified(Expr::from( + ScalarValue::Utf8(None), + ))), } - // if the delimiter is null, then the value of the whole expression is null. - None => Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::Utf8(None), - ))), } - } - Expr::Literal(d) => internal_err!( + d => internal_err!( "The scalar {d} should be casted to string type during the type coercion." 
), + }, _ => { let mut args = args .iter() @@ -397,7 +400,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v) => v.value().is_null(), _ => false, } } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 7f918c03e3ac..154b3faf4711 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -31,7 +31,9 @@ use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; -use datafusion_expr::{expr, lit, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{ + expr, lit, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, Scalar, +}; use datafusion_physical_expr::execution_props::ExecutionProps; /// This struct rewrite the sub query plan by pull up the correlated @@ -433,9 +435,9 @@ fn agg_exprs_evaluation_result_on_empty_batch( let new_expr = match expr { Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { if func.name() == "count" { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) + Transformed::yes(Expr::from(ScalarValue::Int64(Some(0)))) } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::from(ScalarValue::Null)) } } _ => Transformed::no(expr), @@ -449,7 +451,13 @@ fn agg_exprs_evaluation_result_on_empty_batch( let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; - if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_))) { + if matches!( + result_expr, + Expr::Literal(Scalar { + value: ScalarValue::Int64(_), + .. + }) + ) { expr_result_map_for_count_bug .insert(e.schema_name().to_string(), result_expr); } @@ -525,10 +533,19 @@ fn filter_exprs_evaluation_result_on_empty_batch( let result_expr = simplifier.simplify(result_expr)?; match &result_expr { // evaluate to false or null on empty batch, no need to pull up - Expr::Literal(ScalarValue::Null) - | Expr::Literal(ScalarValue::Boolean(Some(false))) => None, + Expr::Literal(Scalar { + value: ScalarValue::Null, + .. + }) + | Expr::Literal(Scalar { + value: ScalarValue::Boolean(Some(false)), + .. + }) => None, // evaluate to true on empty batch, need to pull up the expr - Expr::Literal(ScalarValue::Boolean(Some(true))) => { + Expr::Literal(Scalar { + value: ScalarValue::Boolean(Some(true)), + .. 
+ }) => { for (name, exprs) in input_expr_result_map_for_count_bug { expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); } @@ -543,7 +560,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( Box::new(result_expr.clone()), Box::new(input_expr.clone()), )], - else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), + else_expr: Some(Box::new(Expr::from(ScalarValue::Null))), }); let expr_key = new_expr.schema_name().to_string(); expr_result_map_for_count_bug.insert(expr_key, new_expr); diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 4ed2ac8ba1a4..a23e878c6aa5 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -19,7 +19,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{EmptyRelation, Expr, Filter, LogicalPlan}; +use datafusion_expr::{EmptyRelation, Expr, Filter, LogicalPlan, Scalar}; use std::sync::Arc; use crate::optimizer::ApplyOrder; @@ -60,7 +60,11 @@ impl OptimizerRule for EliminateFilter { ) -> Result> { match plan { LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(v)), + predicate: + Expr::Literal(Scalar { + value: ScalarValue::Boolean(v), + .. + }), input, .. }) => match v { @@ -111,7 +115,7 @@ mod tests { #[test] fn filter_null() -> Result<()> { - let filter_expr = Expr::Literal(ScalarValue::Boolean(None)); + let filter_expr = Expr::from(ScalarValue::Boolean(None)); let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index f9b79e036f9b..9f4fe7bb110f 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -53,20 +53,23 @@ impl OptimizerRule for EliminateJoin { ) -> Result> { match plan { LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { - match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(true)))) => { - Ok(Transformed::yes(LogicalPlan::CrossJoin(CrossJoin { - left: join.left, - right: join.right, - schema: join.schema, - }))) - } - Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( - Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: join.schema, - })), - ), + match &join.filter { + Some(Expr::Literal(scalar)) => match scalar.value() { + ScalarValue::Boolean(Some(true)) => { + Ok(Transformed::yes(LogicalPlan::CrossJoin(CrossJoin { + left: join.left, + right: join.right, + schema: join.schema, + }))) + } + ScalarValue::Boolean(Some(false)) => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: join.schema, + }), + )), + _ => Ok(Transformed::no(LogicalPlan::Join(join))), + }, _ => Ok(Transformed::no(LogicalPlan::Join(join))), } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 6409bb9e03f7..d3a5cfc46740 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -357,7 +357,7 @@ fn build_join( ), ( Box::new(Expr::Not(Box::new(filter.clone()))), - Box::new(Expr::Literal(ScalarValue::Null)), + Box::new(Expr::from(ScalarValue::Null)), ), ], else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( diff --git 
a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 67d6bf8977a4..ef9c8480dc86 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,13 +32,16 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{InList, InSubquery, WindowFunction}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; +use datafusion_expr::{ + expr::{InList, InSubquery, WindowFunction}, + Scalar, +}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; use crate::analyzer::type_coercion::TypeCoercionRewriter; @@ -477,9 +480,9 @@ struct ConstEvaluator<'a> { /// The simplify result of ConstEvaluator enum ConstSimplifyResult { // Expr was simplified and contains the new expression - Simplified(ScalarValue), + Simplified(Scalar), // Expr was not simplified and original value is returned - NotSimplified(ScalarValue), + NotSimplified(Scalar), // Evaluation encountered an error, contains the original expression SimplifyRuntimeError(DataFusionError, Expr), } @@ -643,19 +646,20 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else if as_list_array(&a).is_ok() { - ConstSimplifyResult::Simplified(ScalarValue::List( - a.as_list::().to_owned().into(), - )) + ConstSimplifyResult::Simplified( + ScalarValue::List(a.as_list::().to_owned().into()).into(), + ) } else if as_large_list_array(&a).is_ok() { - ConstSimplifyResult::Simplified(ScalarValue::LargeList( - a.as_list::().to_owned().into(), - )) + ConstSimplifyResult::Simplified( + ScalarValue::LargeList(a.as_list::().to_owned().into()) + .into(), + ) } else { // Non-ListArray - match ScalarValue::try_from_array(&a, 0) { + match Scalar::try_from_array(&a, 0) { Ok(s) => { // TODO: support the optimization for `Map` type after support impl hash for it - if matches!(&s, ScalarValue::Map(_)) { + if matches!(&s.value(), ScalarValue::Map(_)) { ConstSimplifyResult::SimplifyRuntimeError( DataFusionError::NotImplemented("Const evaluate for Map type is still not supported".to_string()), expr, @@ -679,7 +683,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s.into_value()) + ConstSimplifyResult::Simplified(s) } } } @@ -1061,7 +1065,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( + Transformed::yes(Expr::from(ScalarValue::new_zero( &info.get_data_type(&left)?, )?)) } @@ -1072,7 +1076,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( + Transformed::yes(Expr::from(ScalarValue::new_zero( &info.get_data_type(&left)?, )?)) } @@ -1147,7 +1151,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? 
=> { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + Transformed::yes(Expr::from(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } @@ -1158,7 +1162,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + Transformed::yes(Expr::from(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } @@ -1233,7 +1237,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + Transformed::yes(Expr::from(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } @@ -1244,7 +1248,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + Transformed::yes(Expr::from(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } @@ -1257,7 +1261,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); Transformed::yes(if expr == *right { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) + Expr::from(ScalarValue::new_zero(&info.get_data_type(&right)?)?) } else { expr }) @@ -1271,7 +1275,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); Transformed::yes(if expr == *left { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Expr::from(ScalarValue::new_zero(&info.get_data_type(&left)?)?) 
} else { expr }) @@ -1452,7 +1456,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }) if !is_null(&expr) && matches!( pattern.as_ref(), - Expr::Literal(ScalarValue::Utf8(Some(pattern_str))) if pattern_str == "%" + Expr::Literal(Scalar{ value: ScalarValue::Utf8(Some(pattern_str)), ..}) if pattern_str == "%" ) => { Transformed::yes(lit(!negated)) @@ -1476,9 +1480,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { expr, list, negated, - }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { - Transformed::yes(lit(negated)) - } + }) if list.is_empty() && !is_null(&expr) => Transformed::yes(lit(negated)), // null in (x, y, z) --> null // null not in (x, y, z) --> null @@ -2032,7 +2034,7 @@ mod tests { #[test] fn test_simplify_multiply_by_null() { - let null = Expr::Literal(ScalarValue::Null); + let null = Expr::from(ScalarValue::Null); // A * null --> null { let expr = col("c2") * null.clone(); diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index afcbe528083b..8047485c1c7a 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -90,7 +90,10 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { high.as_ref(), ) { let expr_interval = NullableInterval::NotNull { - values: Interval::try_new(low.clone(), high.clone())?, + values: Interval::try_new( + low.value().clone(), + high.value().clone(), + )?, }; let contains = expr_interval.contains(*interval)?; @@ -116,7 +119,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { if let Expr::Literal(value) = left.as_ref() { - Some(Cow::Owned(value.clone().into())) + Some(Cow::Owned(value.value().clone().into())) } else { None } @@ -127,7 +130,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { if let Expr::Literal(value) = right.as_ref() { - Some(Cow::Owned(value.clone().into())) + Some(Cow::Owned(value.value().clone().into())) } else { None } @@ -169,9 +172,9 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { .iter() .filter_map(|expr| { if let Expr::Literal(item) = expr { - match interval - .contains(NullableInterval::from(item.clone())) - { + match interval.contains(NullableInterval::from( + item.value().clone(), + )) { // If we know for certain the value isn't in the column's interval, // we can skip checking it. Ok(interval) if interval.is_certainly_false() => None, @@ -417,7 +420,7 @@ mod tests { let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); let output = col("x").rewrite(&mut rewriter).data().unwrap(); - assert_eq!(output, Expr::Literal(scalar.clone())); + assert_eq!(output, Expr::from(scalar.clone())); } } diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 6c99f18ab0f6..f85d955eadb7 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -16,7 +16,7 @@ // under the License. 
use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{lit, BinaryExpr, Expr, Like, Operator}; +use datafusion_expr::{lit, BinaryExpr, Expr, Like, Operator, Scalar}; use regex_syntax::hir::{Capture, Hir, HirKind, Literal, Look}; /// Maximum number of regex alternations (`foo|bar|...`) that will be expanded into multiple `LIKE` expressions. @@ -42,7 +42,11 @@ pub fn simplify_regex_expr( ) -> Result { let mode = OperatorMode::new(&op); - if let Expr::Literal(ScalarValue::Utf8(Some(pattern))) = right.as_ref() { + if let Expr::Literal(Scalar { + value: ScalarValue::Utf8(Some(pattern)), + .. + }) = right.as_ref() + { match regex_syntax::Parser::new().parse(pattern) { Ok(hir) => { let kind = hir.kind(); @@ -100,7 +104,7 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), + pattern: Box::new(Expr::from(ScalarValue::from(pattern))), escape_char: None, case_insensitive: self.i, }; diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 38bfc1a93403..f9250fdd0914 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -134,47 +134,56 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> pub fn is_zero(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(0))) - | Expr::Literal(ScalarValue::Int16(Some(0))) - | Expr::Literal(ScalarValue::Int32(Some(0))) - | Expr::Literal(ScalarValue::Int64(Some(0))) - | Expr::Literal(ScalarValue::UInt8(Some(0))) - | Expr::Literal(ScalarValue::UInt16(Some(0))) - | Expr::Literal(ScalarValue::UInt32(Some(0))) - | Expr::Literal(ScalarValue::UInt64(Some(0))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s)) if *v == 0 => true, + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Int8(Some(0)) + | ScalarValue::Int16(Some(0)) + | ScalarValue::Int32(Some(0)) + | ScalarValue::Int64(Some(0)) + | ScalarValue::UInt8(Some(0)) + | ScalarValue::UInt16(Some(0)) + | ScalarValue::UInt32(Some(0)) + | ScalarValue::UInt64(Some(0)) => true, + ScalarValue::Float32(Some(v)) if *v == 0. => true, + ScalarValue::Float64(Some(v)) if *v == 0. => true, + ScalarValue::Decimal128(Some(v), _p, _s) if *v == 0 => true, + _ => false, + }, _ => false, } } pub fn is_one(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(1))) - | Expr::Literal(ScalarValue::Int16(Some(1))) - | Expr::Literal(ScalarValue::Int32(Some(1))) - | Expr::Literal(ScalarValue::Int64(Some(1))) - | Expr::Literal(ScalarValue::UInt8(Some(1))) - | Expr::Literal(ScalarValue::UInt16(Some(1))) - | Expr::Literal(ScalarValue::UInt32(Some(1))) - | Expr::Literal(ScalarValue::UInt64(Some(1))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 1. 
=> true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s)) => { - *s >= 0 - && POWS_OF_TEN - .get(*s as usize) - .map(|x| x == v) - .unwrap_or_default() - } + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Int8(Some(1)) + | ScalarValue::Int16(Some(1)) + | ScalarValue::Int32(Some(1)) + | ScalarValue::Int64(Some(1)) + | ScalarValue::UInt8(Some(1)) + | ScalarValue::UInt16(Some(1)) + | ScalarValue::UInt32(Some(1)) + | ScalarValue::UInt64(Some(1)) => true, + ScalarValue::Float32(Some(v)) if *v == 1. => true, + ScalarValue::Float64(Some(v)) if *v == 1. => true, + ScalarValue::Decimal128(Some(v), _p, s) => { + *s >= 0 + && POWS_OF_TEN + .get(*s as usize) + .map(|x| x == v) + .unwrap_or_default() + } + _ => false, + }, _ => false, } } pub fn is_true(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => *v, + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Boolean(Some(v)) => *v, + _ => false, + }, _ => false, } } @@ -182,24 +191,27 @@ pub fn is_true(expr: &Expr) -> bool { /// returns true if expr is a /// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise pub fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(ScalarValue::Boolean(_))) + matches!(expr, Expr::Literal(scalar) if matches!(scalar.value(), ScalarValue::Boolean(_))) } /// Return a literal NULL value of Boolean data type pub fn lit_bool_null() -> Expr { - Expr::Literal(ScalarValue::Boolean(None)) + Expr::from(ScalarValue::Boolean(None)) } pub fn is_null(expr: &Expr) -> bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v) => v.value().is_null(), _ => false, } } pub fn is_false(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => !(*v), + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Boolean(Some(v)) => !(*v), + _ => false, + }, _ => false, } } @@ -223,7 +235,10 @@ pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool { /// `Expr::Literal(ScalarValue::Boolean(v))`. 
pub fn as_bool_lit(expr: &Expr) -> Result> { match expr { - Expr::Literal(ScalarValue::Boolean(v)) => Ok(*v), + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Boolean(v) => Ok(*v), + _ => internal_err!("Expected boolean literal, got {expr:?}"), + }, _ => internal_err!("Expected boolean literal, got {expr:?}"), } } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 22e3c0ddd076..779d1e8950b0 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -33,7 +33,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::utils::merge_schema; -use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan}; +use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Scalar}; /// [`UnwrapCastInComparison`] attempts to remove casts from /// comparisons to literals ([`ScalarValue`]s) by applying the casts @@ -314,14 +314,14 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool { /// Convert a literal value from one data type to another fn try_cast_literal_to_type( - lit_value: &ScalarValue, + lit_value: &Scalar, target_type: &DataType, ) -> Option { let lit_data_type = lit_value.data_type(); - if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) { + if !is_supported_type(lit_data_type) || !is_supported_type(target_type) { return None; } - if lit_value.is_null() { + if lit_value.value().is_null() { // null value can be cast to any type of null value return ScalarValue::try_from(target_type).ok(); } @@ -332,11 +332,11 @@ fn try_cast_literal_to_type( /// Convert a numeric value from one numeric data type to another fn try_cast_numeric_literal( - lit_value: &ScalarValue, + lit_value: &Scalar, target_type: &DataType, ) -> Option { let lit_data_type = lit_value.data_type(); - if !is_supported_numeric_type(&lit_data_type) + if !is_supported_numeric_type(lit_data_type) || !is_supported_numeric_type(target_type) { return None; @@ -374,7 +374,7 @@ fn try_cast_numeric_literal( ), _ => return None, }; - let lit_value_target_type = match lit_value { + let lit_value_target_type = match lit_value.value() { ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), @@ -426,7 +426,7 @@ fn try_cast_numeric_literal( DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), DataType::Timestamp(TimeUnit::Second, tz) => { let value = cast_between_timestamp( - &lit_data_type, + lit_data_type, &DataType::Timestamp(TimeUnit::Second, tz.clone()), value, ); @@ -434,7 +434,7 @@ fn try_cast_numeric_literal( } DataType::Timestamp(TimeUnit::Millisecond, tz) => { let value = cast_between_timestamp( - &lit_data_type, + lit_data_type, &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), value, ); @@ -442,7 +442,7 @@ fn try_cast_numeric_literal( } DataType::Timestamp(TimeUnit::Microsecond, tz) => { let value = cast_between_timestamp( - &lit_data_type, + lit_data_type, &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), value, ); @@ -450,7 +450,7 @@ fn try_cast_numeric_literal( } DataType::Timestamp(TimeUnit::Nanosecond, tz) => { let value = cast_between_timestamp( - &lit_data_type, + lit_data_type, 
&DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), value, ); @@ -472,10 +472,10 @@ fn try_cast_numeric_literal( } fn try_cast_string_literal( - lit_value: &ScalarValue, + lit_value: &Scalar, target_type: &DataType, ) -> Option { - let string_value = match lit_value { + let string_value = match lit_value.value() { ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => { s.clone() } @@ -492,11 +492,11 @@ fn try_cast_string_literal( /// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary fn try_cast_dictionary( - lit_value: &ScalarValue, + lit_value: &Scalar, target_type: &DataType, ) -> Option { let lit_value_type = lit_value.data_type(); - let result_scalar = match (lit_value, target_type) { + let result_scalar = match (lit_value.value(), target_type) { // Unwrap dictionary when inner type matches target type (ScalarValue::Dictionary(_, inner_value), _) if inner_value.data_type() == *target_type => @@ -505,9 +505,12 @@ fn try_cast_dictionary( } // Wrap type when target type is dictionary (_, DataType::Dictionary(index_type, inner_type)) - if **inner_type == lit_value_type => + if inner_type.as_ref() == lit_value_type => { - ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) + ScalarValue::Dictionary( + index_type.clone(), + Box::new(lit_value.value().clone()), + ) } _ => { return None; @@ -1180,10 +1183,11 @@ mod tests { target_type: DataType, expected_result: ExpectedCast, ) { - let actual_value = try_cast_literal_to_type(&literal, &target_type); + let scalar = Scalar::from(literal); + let actual_value = try_cast_literal_to_type(&scalar, &target_type); println!("expect_cast: "); - println!(" {literal:?} --> {target_type:?}"); + println!(" {scalar:?} --> {target_type:?}"); println!(" expected_result: {expected_result:?}"); println!(" actual_result: {actual_value:?}"); @@ -1197,7 +1201,7 @@ mod tests { // Verify that calling the arrow // cast kernel yields the same results // input array - let literal_array = literal + let literal_array = scalar .to_array_of_size(1) .expect("Failed to convert to array of size"); let expected_array = expected_value @@ -1212,7 +1216,7 @@ mod tests { assert_eq!( &expected_array, &cast_array, - "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}" + "Result of casting {scalar:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}" ); // Verify that for timestamp types the timezones are the same @@ -1239,7 +1243,7 @@ mod tests { fn test_try_cast_literal_to_timestamp() { // same timestamp let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), + &ScalarValue::TimestampNanosecond(Some(123456), None).into(), &DataType::Timestamp(TimeUnit::Nanosecond, None), ) .unwrap(); @@ -1251,7 +1255,7 @@ mod tests { // TimestampNanosecond to TimestampMicrosecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), + &ScalarValue::TimestampNanosecond(Some(123456), None).into(), &DataType::Timestamp(TimeUnit::Microsecond, None), ) .unwrap(); @@ -1263,7 +1267,7 @@ mod tests { // TimestampNanosecond to TimestampMillisecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), + &ScalarValue::TimestampNanosecond(Some(123456), None).into(), &DataType::Timestamp(TimeUnit::Millisecond, None), ) .unwrap(); @@ -1272,7 +1276,7 @@ mod tests { // TimestampNanosecond to TimestampSecond let new_scalar = 
try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), + &ScalarValue::TimestampNanosecond(Some(123456), None).into(), &DataType::Timestamp(TimeUnit::Second, None), ) .unwrap(); @@ -1281,7 +1285,7 @@ mod tests { // TimestampMicrosecond to TimestampNanosecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMicrosecond(Some(123), None), + &ScalarValue::TimestampMicrosecond(Some(123), None).into(), &DataType::Timestamp(TimeUnit::Nanosecond, None), ) .unwrap(); @@ -1293,7 +1297,7 @@ mod tests { // TimestampMicrosecond to TimestampMillisecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMicrosecond(Some(123), None), + &ScalarValue::TimestampMicrosecond(Some(123), None).into(), &DataType::Timestamp(TimeUnit::Millisecond, None), ) .unwrap(); @@ -1302,7 +1306,7 @@ mod tests { // TimestampMicrosecond to TimestampSecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMicrosecond(Some(123456789), None), + &ScalarValue::TimestampMicrosecond(Some(123456789), None).into(), &DataType::Timestamp(TimeUnit::Second, None), ) .unwrap(); @@ -1310,7 +1314,7 @@ mod tests { // TimestampMillisecond to TimestampNanosecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMillisecond(Some(123), None), + &ScalarValue::TimestampMillisecond(Some(123), None).into(), &DataType::Timestamp(TimeUnit::Nanosecond, None), ) .unwrap(); @@ -1321,7 +1325,7 @@ mod tests { // TimestampMillisecond to TimestampMicrosecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMillisecond(Some(123), None), + &ScalarValue::TimestampMillisecond(Some(123), None).into(), &DataType::Timestamp(TimeUnit::Microsecond, None), ) .unwrap(); @@ -1331,7 +1335,7 @@ mod tests { ); // TimestampMillisecond to TimestampSecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMillisecond(Some(123456789), None), + &ScalarValue::TimestampMillisecond(Some(123456789), None).into(), &DataType::Timestamp(TimeUnit::Second, None), ) .unwrap(); @@ -1339,7 +1343,7 @@ mod tests { // TimestampSecond to TimestampNanosecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(123), None), + &ScalarValue::TimestampSecond(Some(123), None).into(), &DataType::Timestamp(TimeUnit::Nanosecond, None), ) .unwrap(); @@ -1350,7 +1354,7 @@ mod tests { // TimestampSecond to TimestampMicrosecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(123), None), + &ScalarValue::TimestampSecond(Some(123), None).into(), &DataType::Timestamp(TimeUnit::Microsecond, None), ) .unwrap(); @@ -1361,7 +1365,7 @@ mod tests { // TimestampSecond to TimestampMillisecond let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(123), None), + &ScalarValue::TimestampSecond(Some(123), None).into(), &DataType::Timestamp(TimeUnit::Millisecond, None), ) .unwrap(); @@ -1372,7 +1376,7 @@ mod tests { // overflow let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(i64::MAX), None), + &ScalarValue::TimestampSecond(Some(i64::MAX), None).into(), &DataType::Timestamp(TimeUnit::Millisecond, None), ) .unwrap(); diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 9eda1277c263..e75e458b399f 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -31,7 +31,7 @@ fn make_col(name: &str, index: usize) -> Arc { } fn make_lit_i32(n: i32) -> Arc { - 
Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) + Arc::new(Literal::from(ScalarValue::Int32(Some(n)))) } fn criterion_benchmark(c: &mut Criterion) { diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index c1851ddb22b5..8dc40c41fe9b 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -765,14 +765,14 @@ mod tests { #[test] fn test_contains_any() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + let lit_true = Arc::new(Literal::from(ScalarValue::Boolean(Some(true)))) as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + let lit_false = Arc::new(Literal::from(ScalarValue::Boolean(Some(false)))) as Arc; let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + Arc::new(Literal::from(ScalarValue::Int32(Some(2)))) as Arc; let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))) as Arc; let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; let cls1 = diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 6a1268ef8cdb..5f18ffcda6e9 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -1391,7 +1391,10 @@ fn get_expr_properties( } else if let Some(literal) = expr.as_any().downcast_ref::() { Ok(ExprProperties { sort_properties: SortProperties::Singleton, - range: Interval::try_new(literal.value().clone(), literal.value().clone())?, + range: Interval::try_new( + literal.scalar().value().clone(), + literal.scalar().value().clone(), + )?, }) } else { // Find orderings of its children diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index dbae695abb97..be58178511bb 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -2450,7 +2450,7 @@ mod tests { literal: ScalarValue, expected: ArrayRef, ) -> Result<()> { - let lit = Arc::new(Literal::new(literal)); + let lit = Arc::new(Literal::from(literal)); let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?; let batch = RecordBatch::try_new(schema, data)?; let result = arithmetic_op diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index ef8a284680e6..89e9d990e6a6 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -128,7 +128,7 @@ impl CaseExpr { // during SQL planning, but not necessarily for other use cases) let else_expr = match &else_expr { Some(e) => match e.as_any().downcast_ref::() { - Some(lit) if lit.value().is_null() => None, + Some(lit) if lit.scalar().value().is_null() => None, _ => else_expr, }, _ => else_expr, @@ -1094,7 +1094,7 @@ mod tests { .transform(|e| { let transformed = match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { + Some(lit_value) => match lit_value.scalar().value() { ScalarValue::Utf8(Some(str_value)) => { Some(lit(str_value.to_uppercase())) } @@ -1115,7 +1115,7 @@ mod tests { .transform_down(|e| { let transformed = match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { + Some(lit_value) => match lit_value.scalar().value() { ScalarValue::Utf8(Some(str_value)) 
=> { Some(lit(str_value.to_uppercase())) } @@ -1182,7 +1182,7 @@ mod tests { } fn make_lit_i32(n: i32) -> Arc { - Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) + Arc::new(Literal::from(ScalarValue::Int32(Some(n)))) } fn generate_case_when_with_type_coercion( diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index e064abbca35c..42c592958125 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -28,7 +28,7 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, Scalar}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; @@ -36,24 +36,30 @@ use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value #[derive(Debug, PartialEq, Eq, Hash)] pub struct Literal { - value: ScalarValue, + scalar: Scalar, +} + +impl From for Literal { + fn from(value: ScalarValue) -> Self { + Self::new(Scalar::from(value)) + } } impl Literal { /// Create a literal value expression - pub fn new(value: ScalarValue) -> Self { - Self { value } + pub fn new(value: Scalar) -> Self { + Self { scalar: value } } /// Get the scalar value - pub fn value(&self) -> &ScalarValue { - &self.value + pub fn scalar(&self) -> &Scalar { + &self.scalar } } impl std::fmt::Display for Literal { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.value) + write!(f, "{}", self.scalar) } } @@ -64,15 +70,15 @@ impl PhysicalExpr for Literal { } fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.value.data_type()) + Ok(self.scalar.data_type().clone()) } fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(self.value.is_null()) + Ok(self.scalar.value().is_null()) } fn evaluate(&self, _batch: &RecordBatch) -> Result { - Ok(ColumnarValue::from(self.value.clone())) + Ok(ColumnarValue::Scalar(self.scalar.clone())) } fn children(&self) -> Vec<&Arc> { @@ -94,7 +100,10 @@ impl PhysicalExpr for Literal { fn get_properties(&self, _children: &[ExprProperties]) -> Result { Ok(ExprProperties { sort_properties: SortProperties::Singleton, - range: Interval::try_new(self.value().clone(), self.value().clone())?, + range: Interval::try_new( + self.scalar.value().clone(), + self.scalar.value().clone(), + )?, }) } } diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index f05ac3624b8e..3f861013b345 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -182,7 +182,7 @@ impl ExprIntervalGraphNode { pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { let expr = Arc::clone(&node.expr); if let Some(literal) = expr.as_any().downcast_ref::() { - let value = literal.value(); + let value = literal.scalar().value(); Interval::try_new(value.clone(), value.clone()) .map(|interval| Self::new_with_interval(expr, interval)) } else { @@ -530,7 +530,7 @@ impl ExprIntervalGraph { /// let expr = Arc::new(BinaryExpr::new( /// Arc::new(Column::new("gnz", 0)), /// Operator::Plus, - /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + /// Arc::new(Literal::from(ScalarValue::Int32(Some(10)))), /// )); /// /// let schema = 
Schema::new(vec![Field::new("gnz".to_string(), DataType::Int32, true)]); @@ -873,7 +873,7 @@ mod tests { let left_and_1 = Arc::new(BinaryExpr::new( Arc::clone(&left_col) as Arc, Operator::Plus, - Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(5)))), )); let expr = Arc::new(BinaryExpr::new( left_and_1, @@ -1198,7 +1198,7 @@ mod tests { Arc::new(Column::new("b", 1)), )), Operator::Plus, - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), )); let right_expr = Arc::new(BinaryExpr::new( @@ -1244,7 +1244,7 @@ mod tests { Arc::new(Column::new("b", 1)), )), Operator::Plus, - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), )); let right_expr = Arc::new(BinaryExpr::new( @@ -1292,7 +1292,7 @@ mod tests { Arc::new(Column::new("b", 1)), )), Operator::Plus, - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), )); let right_expr = Arc::new(BinaryExpr::new( @@ -1339,7 +1339,7 @@ mod tests { Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Plus, - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), )), Operator::Plus, Arc::new(Column::new("b", 1)), @@ -1383,7 +1383,7 @@ mod tests { let expression = BinaryExpr::new( Arc::new(Column::new("ts_column", 0)), Operator::Plus, - Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))), + Arc::new(Literal::from(ScalarValue::new_interval_mdn(0, 1, 321))), ); let parent = Interval::try_new( // 15.10.2020 - 10:11:12.000_000_321 AM diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs index cedf55bccbf2..49163fba60af 100644 --- a/datafusion/physical-expr/src/intervals/test_utils.rs +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -43,17 +43,17 @@ pub fn gen_conjunctive_numerical_expr( let left_and_1 = Arc::new(BinaryExpr::new( Arc::clone(&left_col), op_1, - Arc::new(Literal::new(a)), + Arc::new(Literal::from(a)), )); let left_and_2 = Arc::new(BinaryExpr::new( Arc::clone(&right_col), op_2, - Arc::new(Literal::new(b)), + Arc::new(Literal::from(b)), )); let right_and_1 = - Arc::new(BinaryExpr::new(left_col, op_3, Arc::new(Literal::new(c)))); + Arc::new(BinaryExpr::new(left_col, op_3, Arc::new(Literal::from(c)))); let right_and_2 = - Arc::new(BinaryExpr::new(right_col, op_4, Arc::new(Literal::new(d)))); + Arc::new(BinaryExpr::new(right_col, op_4, Arc::new(Literal::from(d)))); let (greater_op, less_op) = bounds; let left_expr = Arc::new(BinaryExpr::new(left_and_1, greater_op, left_and_2)); @@ -81,17 +81,17 @@ pub fn gen_conjunctive_temporal_expr( let left_and_1 = binary( Arc::clone(&left_col), op_1, - Arc::new(Literal::new(a)), + Arc::new(Literal::from(a)), schema, )?; let left_and_2 = binary( Arc::clone(&right_col), op_2, - Arc::new(Literal::new(b)), + Arc::new(Literal::from(b)), schema, )?; - let right_and_1 = binary(left_col, op_3, Arc::new(Literal::new(c)), schema)?; - let right_and_2 = binary(right_col, op_4, Arc::new(Literal::new(d)), schema)?; + let right_and_1 = binary(left_col, op_3, Arc::new(Literal::from(c)), schema)?; + let right_and_2 = binary(right_col, op_4, Arc::new(Literal::from(d)), schema)?; let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, left_and_2)); let right_expr = Arc::new(BinaryExpr::new(right_and_1, Operator::Lt, right_and_2)); 
Ok(Arc::new(BinaryExpr::new( diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index c718e6b054ef..d4c91e8ef3b1 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -101,16 +101,16 @@ mod tests { #[test] fn test_physical_exprs_contains() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + let lit_true = Arc::new(Literal::from(ScalarValue::Boolean(Some(true)))) as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + let lit_false = Arc::new(Literal::from(ScalarValue::Boolean(Some(false)))) as Arc; let lit4 = - Arc::new(Literal::new(ScalarValue::Int32(Some(4)))) as Arc; + Arc::new(Literal::from(ScalarValue::Int32(Some(4)))) as Arc; let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + Arc::new(Literal::from(ScalarValue::Int32(Some(2)))) as Arc; let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))) as Arc; let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; @@ -136,14 +136,14 @@ mod tests { #[test] fn test_physical_exprs_equal() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + let lit_true = Arc::new(Literal::from(ScalarValue::Boolean(Some(true)))) as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + let lit_false = Arc::new(Literal::from(ScalarValue::Boolean(Some(false)))) as Arc; let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))) as Arc; let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + Arc::new(Literal::from(ScalarValue::Int32(Some(2)))) as Arc; let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; let vec1 = vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]; @@ -213,13 +213,13 @@ mod tests { #[test] fn test_deduplicate_physical_exprs() { - let lit_true = &(Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + let lit_true = &(Arc::new(Literal::from(ScalarValue::Boolean(Some(true)))) as Arc); - let lit_false = &(Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + let lit_false = &(Arc::new(Literal::from(ScalarValue::Boolean(Some(false)))) as Arc); - let lit4 = &(Arc::new(Literal::new(ScalarValue::Int32(Some(4)))) + let lit4 = &(Arc::new(Literal::from(ScalarValue::Int32(Some(4)))) as Arc); - let lit2 = &(Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) + let lit2 = &(Arc::new(Literal::from(ScalarValue::Int32(Some(2)))) as Arc); let col_a_expr = &(Arc::new(Column::new("a", 0)) as Arc); let col_b_expr = &(Arc::new(Column::new("b", 1)) as Arc); diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index bffc2c46fc1e..7e606fd5e867 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -124,7 +124,7 @@ pub fn create_physical_expr( match execution_props.get_var_provider(VarType::System) { Some(provider) => { let scalar_value = provider.get_value(variable_names.clone())?; - Ok(Arc::new(Literal::new(scalar_value))) + Ok(Arc::new(Literal::from(scalar_value))) } _ => plan_err!("No system variable provider found"), } @@ -132,7 +132,7 @@ pub fn create_physical_expr( match execution_props.get_var_provider(VarType::UserDefined) { Some(provider) => { let 
scalar_value = provider.get_value(variable_names.clone())?; - Ok(Arc::new(Literal::new(scalar_value))) + Ok(Arc::new(Literal::from(scalar_value))) } _ => plan_err!("No user defined variable provider found"), } @@ -168,7 +168,7 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::from(ScalarValue::Boolean(None)), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } @@ -176,7 +176,7 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::from(ScalarValue::Boolean(None)), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } @@ -346,7 +346,7 @@ pub fn create_physical_expr( list, negated, }) => match expr.as_ref() { - Expr::Literal(ScalarValue::Utf8(None)) => { + Expr::Literal(scalar) if scalar.value() == &ScalarValue::Utf8(None) => { Ok(expressions::lit(ScalarValue::Boolean(None))) } _ => { diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index fbb59cc92fa0..1a92b0665444 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -156,7 +156,7 @@ impl LiteralGuarantee { builder.aggregate_multi_conjunct( col, guarantee, - literals.iter().map(|e| e.value()), + literals.iter().map(|e| e.scalar().value()), ) } else { // split disjunction: OR OR ... @@ -211,7 +211,7 @@ impl LiteralGuarantee { builder.aggregate_multi_conjunct( first_term.col, Guarantee::In, - terms.iter().map(|term| term.lit.value()), + terms.iter().map(|term| term.lit.scalar().value()), ) } else { // can't infer anything @@ -274,7 +274,7 @@ impl<'a> GuaranteeBuilder<'a> { self.aggregate_multi_conjunct( col_op_lit.col, col_op_lit.guarantee, - [col_op_lit.lit.value()], + [col_op_lit.lit.scalar().value()], ) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index f9dd973c814e..daa5d7b81c58 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -2330,7 +2330,7 @@ mod tests { let col_a = col("a", &schema)?; let col_b = col("b", &schema)?; - let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1)))); + let const_expr = Arc::new(Literal::from(ScalarValue::Int32(Some(1)))); let groups = PhysicalGroupBy::new( vec![ @@ -2340,15 +2340,15 @@ mod tests { ], vec![ ( - Arc::new(Literal::new(ScalarValue::Float32(None))), + Arc::new(Literal::from(ScalarValue::Float32(None))), "a".to_string(), ), ( - Arc::new(Literal::new(ScalarValue::Float32(None))), + Arc::new(Literal::from(ScalarValue::Float32(None))), "b".to_string(), ), ( - Arc::new(Literal::new(ScalarValue::Int32(None))), + Arc::new(Literal::from(ScalarValue::Int32(None))), "const".to_string(), ), ], diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 417d2098b083..8d032aaad719 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -828,21 +828,21 @@ mod tests { Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::LtEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(53)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(53)))), )), Operator::And, Arc::new(BinaryExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), Operator::Eq, - 
Arc::new(Literal::new(ScalarValue::Int32(Some(3)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(3)))), )), Operator::And, Arc::new(BinaryExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("c", 2)), Operator::LtEq, - Arc::new(Literal::new(ScalarValue::Float32(Some(1075.0)))), + Arc::new(Literal::from(ScalarValue::Float32(Some(1075.0)))), )), Operator::And, Arc::new(BinaryExpr::new( @@ -941,11 +941,11 @@ mod tests { Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Lt, - Arc::new(Literal::new(ScalarValue::Int32(Some(200)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(200)))), )), Operator::And, Arc::new(BinaryExpr::new( - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), Operator::LtEq, Arc::new(Column::new("b", 1)), )), @@ -996,11 +996,11 @@ mod tests { Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Gt, - Arc::new(Literal::new(ScalarValue::Int32(Some(200)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(200)))), )), Operator::And, Arc::new(BinaryExpr::new( - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))), Operator::LtEq, Arc::new(Column::new("b", 1)), )), @@ -1059,7 +1059,7 @@ mod tests { let predicate = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Lt, - Arc::new(Literal::new(ScalarValue::Int32(Some(50)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(50)))), )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); @@ -1098,16 +1098,16 @@ mod tests { Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::LtEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(10)))), )), Operator::And, Arc::new(BinaryExpr::new( - Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(0)))), Operator::LtEq, Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Minus, - Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(5)))), )), )), )); @@ -1142,7 +1142,7 @@ mod tests { let predicate = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Eq, - Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(10)))), )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); @@ -1164,7 +1164,7 @@ mod tests { let predicate = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Eq, - Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(10)))), )); let filter = FilterExec::try_new(predicate, input)?; assert!(filter.with_default_selectivity(120).is_err()); @@ -1190,7 +1190,7 @@ mod tests { let predicate = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Eq, - Arc::new(Literal::new(ScalarValue::Decimal128(Some(10), 10, 10))), + Arc::new(Literal::from(ScalarValue::Decimal128(Some(10), 10, 10))), )); let filter = FilterExec::try_new(predicate, input)?; let statistics = filter.statistics()?; diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 48d648c89a35..f6446548ca40 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -2509,7 +2509,7 @@ mod tests { let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::NotEq, - 
Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(10)))), )) as Arc; let filter = JoinFilter::new( @@ -2548,7 +2548,7 @@ mod tests { let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::Gt, - Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(10)))), )) as Arc; let filter = JoinFilter::new(filter_expression, column_indices, intermediate_schema); @@ -2633,7 +2633,7 @@ mod tests { let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(9)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(9)))), )) as Arc; let filter = JoinFilter::new( @@ -2674,7 +2674,7 @@ mod tests { let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::Gt, - Arc::new(Literal::new(ScalarValue::Int32(Some(11)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(11)))), )) as Arc; let filter = @@ -2755,7 +2755,7 @@ mod tests { let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(8)))), )) as Arc; let filter = JoinFilter::new( @@ -2797,7 +2797,7 @@ mod tests { let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(8)))), )) as Arc; let filter = @@ -2884,7 +2884,7 @@ mod tests { let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(13)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(13)))), )) as Arc; let filter = JoinFilter::new( @@ -2931,7 +2931,7 @@ mod tests { let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(8)))), )) as Arc; let filter = diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 029003374acc..ebc8a7338d9a 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -825,13 +825,13 @@ mod tests { let left_filter = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(8)))), )) as Arc; // right.b2!=10 let right_filter = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 1)), Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(10)))), )) as Arc; // filter = left.b1!=8 and right.b2!=10 // after filter: @@ -1199,26 +1199,26 @@ mod tests { let left_mod = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::Modulo, - Arc::new(Literal::new(ScalarValue::Int32(Some(3)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(3)))), )) as Arc; // left.b1 % 3 != 0 let left_filter = Arc::new(BinaryExpr::new( left_mod, Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(0)))), )) as Arc; // right.b2 % 5 let right_mod = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 1)), Operator::Modulo, - 
Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(5)))), )) as Arc; // right.b2 % 5 != 0 let right_filter = Arc::new(BinaryExpr::new( right_mod, Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + Arc::new(Literal::from(ScalarValue::Int32(Some(0)))), )) as Arc; // filter = left.b1 % 3 != 0 and right.b2 % 5 != 0 let filter_expression = diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 50f6f4a93097..a8a6333f7cfe 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -1542,7 +1542,7 @@ mod tests { .unwrap(); let expressions = vec![PhysicalSortExpr { - expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), + expr: Arc::new(Literal::from(ScalarValue::Int64(Some(1)))), options: SortOptions::default(), }]; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 6aafaad0ad77..4a8788aeeaa5 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -34,7 +34,7 @@ use datafusion_common::{ exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - BuiltInWindowFunction, PartitionEvaluator, ReversedUDWF, WindowFrame, + BuiltInWindowFunction, PartitionEvaluator, ReversedUDWF, Scalar, WindowFrame, WindowFunctionDefinition, WindowUDF, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; @@ -168,7 +168,7 @@ fn window_expr_from_aggregate_expr( fn get_scalar_value_from_args( args: &[Arc], index: usize, -) -> Result> { +) -> Result> { Ok(if let Some(field) = args.get(index) { let tmp = field .as_any() @@ -176,7 +176,7 @@ fn get_scalar_value_from_args( .ok_or_else(|| DataFusionError::NotImplemented( format!("There is only support Literal types for field at idx: {index} in Window Function"), ))? 
- .value() + .scalar() .clone(); Some(tmp) } else { @@ -184,36 +184,36 @@ fn get_scalar_value_from_args( }) } -fn get_signed_integer(value: ScalarValue) -> Result { - if value.is_null() { +fn get_signed_integer(scalar: Scalar) -> Result { + if scalar.value().is_null() { return Ok(0); } - if !value.data_type().is_integer() { + if !scalar.data_type().is_integer() { return exec_err!("Expected an integer value"); } - value.cast_to(&DataType::Int64)?.try_into() + scalar.cast_to(&DataType::Int64)?.into_value().try_into() } -fn get_unsigned_integer(value: ScalarValue) -> Result { - if value.is_null() { +fn get_unsigned_integer(scalar: Scalar) -> Result { + if scalar.value().is_null() { return Ok(0); } - if !value.data_type().is_integer() { + if !scalar.data_type().is_integer() { return exec_err!("Expected an integer value"); } - value.cast_to(&DataType::UInt64)?.try_into() + scalar.cast_to(&DataType::UInt64)?.into_value().try_into() } fn get_casted_value( - default_value: Option, + default_value: Option, dtype: &DataType, ) -> Result { match default_value { - Some(v) if !v.data_type().is_null() => v.cast_to(dtype), + Some(v) if !v.data_type().is_null() => Ok(v.cast_to(dtype)?.into_value()), // If None or Null datatype _ => ScalarValue::try_from(dtype), } @@ -241,11 +241,11 @@ fn create_built_in_window_expr( ) })?; - if n.is_null() { + if n.value().is_null() { return exec_err!("NTILE requires a positive integer, but finds NULL"); } - if n.is_unsigned() { + if n.value().is_unsigned() { let n = get_unsigned_integer(n)?; Arc::new(Ntile::new(name, n, out_data_type)) } else { @@ -297,7 +297,7 @@ fn create_built_in_window_expr( .ok_or_else(|| { exec_datafusion_err!("Expected a signed integer literal for the second argument of nth_value, got {}", args[1]) })? - .value() + .scalar() .clone(), )?; Arc::new(NthValue::nth( diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index 5051c8f9322f..936b2519c64d 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -43,6 +43,7 @@ json = ["serde", "serde_json", "pbjson"] arrow = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } object_store = { workspace = true } pbjson = { workspace = true, optional = true } prost = { workspace = true } diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index d1b4374fc0e7..1f2fc36c0981 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -43,6 +43,7 @@ use datafusion_common::{ Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, DataFusionError, JoinSide, ScalarValue, Statistics, TableReference, }; +use datafusion_expr::Scalar; #[derive(Debug)] pub enum Error { @@ -355,6 +356,16 @@ impl TryFrom<&protobuf::Schema> for Schema { impl TryFrom<&protobuf::ScalarValue> for ScalarValue { type Error = Error; + fn try_from( + scalar: &protobuf::ScalarValue, + ) -> datafusion_common::Result { + Ok(Scalar::try_from(scalar)?.into_value()) + } +} + +impl TryFrom<&protobuf::ScalarValue> for Scalar { + type Error = Error; + fn try_from( scalar: &protobuf::ScalarValue, ) -> datafusion_common::Result { @@ -365,22 +376,22 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .as_ref() .ok_or_else(|| Error::required("value"))?; - Ok(match value { - Value::BoolValue(v) => Self::Boolean(Some(*v)), - Value::Utf8Value(v) => Self::Utf8(Some(v.to_owned())), - 
Value::Utf8ViewValue(v) => Self::Utf8View(Some(v.to_owned())), - Value::LargeUtf8Value(v) => Self::LargeUtf8(Some(v.to_owned())), - Value::Int8Value(v) => Self::Int8(Some(*v as i8)), - Value::Int16Value(v) => Self::Int16(Some(*v as i16)), - Value::Int32Value(v) => Self::Int32(Some(*v)), - Value::Int64Value(v) => Self::Int64(Some(*v)), - Value::Uint8Value(v) => Self::UInt8(Some(*v as u8)), - Value::Uint16Value(v) => Self::UInt16(Some(*v as u16)), - Value::Uint32Value(v) => Self::UInt32(Some(*v)), - Value::Uint64Value(v) => Self::UInt64(Some(*v)), - Value::Float32Value(v) => Self::Float32(Some(*v)), - Value::Float64Value(v) => Self::Float64(Some(*v)), - Value::Date32Value(v) => Self::Date32(Some(*v)), + let value = match value { + Value::BoolValue(v) => ScalarValue::Boolean(Some(*v)), + Value::Utf8Value(v) => ScalarValue::Utf8(Some(v.to_owned())), + Value::Utf8ViewValue(v) => ScalarValue::Utf8View(Some(v.to_owned())), + Value::LargeUtf8Value(v) => ScalarValue::LargeUtf8(Some(v.to_owned())), + Value::Int8Value(v) => ScalarValue::Int8(Some(*v as i8)), + Value::Int16Value(v) => ScalarValue::Int16(Some(*v as i16)), + Value::Int32Value(v) => ScalarValue::Int32(Some(*v)), + Value::Int64Value(v) => ScalarValue::Int64(Some(*v)), + Value::Uint8Value(v) => ScalarValue::UInt8(Some(*v as u8)), + Value::Uint16Value(v) => ScalarValue::UInt16(Some(*v as u16)), + Value::Uint32Value(v) => ScalarValue::UInt32(Some(*v)), + Value::Uint64Value(v) => ScalarValue::UInt64(Some(*v)), + Value::Float32Value(v) => ScalarValue::Float32(Some(*v)), + Value::Float64Value(v) => ScalarValue::Float64(Some(*v)), + Value::Date32Value(v) => ScalarValue::Date32(Some(*v)), // ScalarValue::List is serialized using arrow IPC format Value::ListValue(v) | Value::FixedSizeListValue(v) @@ -474,18 +485,20 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let arr = record_batch.column(0); match value { Value::ListValue(_) => { - Self::List(arr.as_list::().to_owned().into()) + ScalarValue::List(arr.as_list::().to_owned().into()) } Value::LargeListValue(_) => { - Self::LargeList(arr.as_list::().to_owned().into()) - } - Value::FixedSizeListValue(_) => { - Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into()) + ScalarValue::LargeList(arr.as_list::().to_owned().into()) } + Value::FixedSizeListValue(_) => ScalarValue::FixedSizeList( + arr.as_fixed_size_list().to_owned().into(), + ), Value::StructValue(_) => { - Self::Struct(arr.as_struct().to_owned().into()) + ScalarValue::Struct(arr.as_struct().to_owned().into()) + } + Value::MapValue(_) => { + ScalarValue::Map(arr.as_map().to_owned().into()) } - Value::MapValue(_) => Self::Map(arr.as_map().to_owned().into()), _ => unreachable!(), } } @@ -495,7 +508,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } Value::Decimal128Value(val) => { let array = vec_to_array(val.value.clone()); - Self::Decimal128( + ScalarValue::Decimal128( Some(i128::from_be_bytes(array)), val.p as u8, val.s as i8, @@ -503,22 +516,22 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } Value::Decimal256Value(val) => { let array = vec_to_array(val.value.clone()); - Self::Decimal256( + ScalarValue::Decimal256( Some(i256::from_be_bytes(array)), val.p as u8, val.s as i8, ) } - Value::Date64Value(v) => Self::Date64(Some(*v)), + Value::Date64Value(v) => ScalarValue::Date64(Some(*v)), Value::Time32Value(v) => { let time_value = v.value.as_ref().ok_or_else(|| Error::required("value"))?; match time_value { protobuf::scalar_time32_value::Value::Time32SecondValue(t) => { - Self::Time32Second(Some(*t)) + 
ScalarValue::Time32Second(Some(*t)) } protobuf::scalar_time32_value::Value::Time32MillisecondValue(t) => { - Self::Time32Millisecond(Some(*t)) + ScalarValue::Time32Millisecond(Some(*t)) } } } @@ -527,18 +540,24 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { v.value.as_ref().ok_or_else(|| Error::required("value"))?; match time_value { protobuf::scalar_time64_value::Value::Time64MicrosecondValue(t) => { - Self::Time64Microsecond(Some(*t)) + ScalarValue::Time64Microsecond(Some(*t)) } protobuf::scalar_time64_value::Value::Time64NanosecondValue(t) => { - Self::Time64Nanosecond(Some(*t)) + ScalarValue::Time64Nanosecond(Some(*t)) } } } - Value::IntervalYearmonthValue(v) => Self::IntervalYearMonth(Some(*v)), - Value::DurationSecondValue(v) => Self::DurationSecond(Some(*v)), - Value::DurationMillisecondValue(v) => Self::DurationMillisecond(Some(*v)), - Value::DurationMicrosecondValue(v) => Self::DurationMicrosecond(Some(*v)), - Value::DurationNanosecondValue(v) => Self::DurationNanosecond(Some(*v)), + Value::IntervalYearmonthValue(v) => ScalarValue::IntervalYearMonth(Some(*v)), + Value::DurationSecondValue(v) => ScalarValue::DurationSecond(Some(*v)), + Value::DurationMillisecondValue(v) => { + ScalarValue::DurationMillisecond(Some(*v)) + } + Value::DurationMicrosecondValue(v) => { + ScalarValue::DurationMicrosecond(Some(*v)) + } + Value::DurationNanosecondValue(v) => { + ScalarValue::DurationNanosecond(Some(*v)) + } Value::TimestampValue(v) => { let timezone = if v.timezone.is_empty() { None @@ -551,16 +570,16 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { match ts_value { protobuf::scalar_timestamp_value::Value::TimeMicrosecondValue(t) => { - Self::TimestampMicrosecond(Some(*t), timezone) + ScalarValue::TimestampMicrosecond(Some(*t), timezone) } protobuf::scalar_timestamp_value::Value::TimeNanosecondValue(t) => { - Self::TimestampNanosecond(Some(*t), timezone) + ScalarValue::TimestampNanosecond(Some(*t), timezone) } protobuf::scalar_timestamp_value::Value::TimeSecondValue(t) => { - Self::TimestampSecond(Some(*t), timezone) + ScalarValue::TimestampSecond(Some(*t), timezone) } protobuf::scalar_timestamp_value::Value::TimeMillisecondValue(t) => { - Self::TimestampMillisecond(Some(*t), timezone) + ScalarValue::TimestampMillisecond(Some(*t), timezone) } } } @@ -578,15 +597,18 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .as_ref() .try_into()?; - Self::Dictionary(Box::new(index_type), Box::new(value)) + ScalarValue::Dictionary( + Box::new(index_type), + Box::new(value.into_value()), + ) } - Value::BinaryValue(v) => Self::Binary(Some(v.clone())), - Value::BinaryViewValue(v) => Self::BinaryView(Some(v.clone())), - Value::LargeBinaryValue(v) => Self::LargeBinary(Some(v.clone())), - Value::IntervalDaytimeValue(v) => Self::IntervalDayTime(Some( + Value::BinaryValue(v) => ScalarValue::Binary(Some(v.clone())), + Value::BinaryViewValue(v) => ScalarValue::BinaryView(Some(v.clone())), + Value::LargeBinaryValue(v) => ScalarValue::LargeBinary(Some(v.clone())), + Value::IntervalDaytimeValue(v) => ScalarValue::IntervalDayTime(Some( IntervalDayTimeType::make_value(v.days, v.milliseconds), )), - Value::IntervalMonthDayNano(v) => Self::IntervalMonthDayNano(Some( + Value::IntervalMonthDayNano(v) => ScalarValue::IntervalMonthDayNano(Some( IntervalMonthDayNanoType::make_value(v.months, v.days, v.nanos), )), Value::UnionValue(val) => { @@ -612,19 +634,21 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let val = match &val.value { None => None, Some(val) => { - let val: 
ScalarValue = val + let val: Scalar = val .as_ref() .try_into() .map_err(|_| Error::General("Invalid Scalar".to_string()))?; - Some((v_id, Box::new(val))) + Some((v_id, Box::new(val.into_value()))) } }; - Self::Union(val, fields, mode) + ScalarValue::Union(val, fields, mode) } Value::FixedSizeBinaryValue(v) => { - Self::FixedSizeBinary(v.length, Some(v.clone().values)) + ScalarValue::FixedSizeBinary(v.length, Some(v.clone().values)) } - }) + }; + + Ok(Scalar::from(value)) } } @@ -711,8 +735,10 @@ impl From<protobuf::Precision> for Precision<usize> { match precision_type { protobuf::PrecisionInfo::Exact => { if let Some(val) = s.val { - if let Ok(ScalarValue::UInt64(Some(val))) = - ScalarValue::try_from(&val) + if let Ok(Scalar { + value: ScalarValue::UInt64(Some(val)), + .. + }) = Scalar::try_from(&val) { Precision::Exact(val as usize) } else { @@ -724,8 +750,10 @@ impl From<protobuf::Precision> for Precision<usize> { } protobuf::PrecisionInfo::Inexact => { if let Some(val) = s.val { - if let Ok(ScalarValue::UInt64(Some(val))) = - ScalarValue::try_from(&val) + if let Ok(Scalar { + value: ScalarValue::UInt64(Some(val)), + .. + }) = Scalar::try_from(&val) { Precision::Inexact(val as usize) } else { @@ -748,8 +776,8 @@ impl From<protobuf::Precision> for Precision<ScalarValue> { match precision_type { protobuf::PrecisionInfo::Exact => { if let Some(val) = s.val { - if let Ok(val) = ScalarValue::try_from(&val) { - Precision::Exact(val) + if let Ok(val) = Scalar::try_from(&val) { + Precision::Exact(val.into_value()) } else { Precision::Absent } @@ -759,8 +787,8 @@ impl From<protobuf::Precision> for Precision<ScalarValue> { } protobuf::PrecisionInfo::Inexact => { if let Some(val) = s.val { - if let Ok(val) = ScalarValue::try_from(&val) { - Precision::Inexact(val) + if let Ok(val) = Scalar::try_from(&val) { + Precision::Inexact(val.into_value()) } else { Precision::Absent } diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index ebb53ae7577c..1ecfc935bdd7 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -41,6 +41,7 @@ use datafusion_common::{ Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, DataFusionError, JoinSide, ScalarValue, Statistics, }; +use datafusion_expr::Scalar; #[derive(Debug)] pub enum Error { @@ -290,12 +291,20 @@ impl TryFrom<&DFSchemaRef> for protobuf::DfSchema { } } +impl TryFrom<&Scalar> for protobuf::ScalarValue { + type Error = Error; + + fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> { + scalar.value().try_into() + } +} + impl TryFrom<&ScalarValue> for protobuf::ScalarValue { type Error = Error; - fn try_from(val: &ScalarValue) -> Result<Self, Self::Error> { - let data_type = val.data_type(); - match val { + fn try_from(value: &ScalarValue) -> Result<Self, Self::Error> { + let data_type = value.data_type(); + match value { ScalarValue::Boolean(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::BoolValue(*s)) } @@ -358,19 +367,19 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } ScalarValue::List(arr) => { - encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + encode_scalar_nested_value(arr.to_owned() as ArrayRef, value) } ScalarValue::LargeList(arr) => { - encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + encode_scalar_nested_value(arr.to_owned() as ArrayRef, value) } ScalarValue::FixedSizeList(arr) => { - encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + encode_scalar_nested_value(arr.to_owned() as ArrayRef, value) } ScalarValue::Struct(arr) => { - encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) +
encode_scalar_nested_value(arr.to_owned() as ArrayRef, value) } ScalarValue::Map(arr) => { - encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + encode_scalar_nested_value(arr.to_owned() as ArrayRef, value) } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 893255ccc8ce..339e578146b9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -24,7 +24,6 @@ use datafusion_common::{ }; use datafusion_expr::expr::{Alias, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; -use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ expr::{self, InList, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, @@ -33,6 +32,7 @@ use datafusion_expr::{ JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, }; +use datafusion_expr::{ExprFunctionExt, Scalar}; use datafusion_proto_common::{from_proto::FromOptionalField, FromProtoError as Error}; use crate::protobuf::plan_type::PlanTypeEnum::{ @@ -189,11 +189,11 @@ impl TryFrom for WindowFrameBound { match bound_type { protobuf::WindowFrameBoundType::CurrentRow => Ok(Self::CurrentRow), protobuf::WindowFrameBoundType::Preceding => match bound.bound_value { - Some(x) => Ok(Self::Preceding(ScalarValue::try_from(&x)?)), + Some(x) => Ok(Self::Preceding(Scalar::try_from(&x)?.into_value())), None => Ok(Self::Preceding(ScalarValue::UInt64(None))), }, protobuf::WindowFrameBoundType::Following => match bound.bound_value { - Some(x) => Ok(Self::Following(ScalarValue::try_from(&x)?)), + Some(x) => Ok(Self::Following(Scalar::try_from(&x)?.into_value())), None => Ok(Self::Following(ScalarValue::UInt64(None))), }, } @@ -258,7 +258,7 @@ pub fn parse_expr( } ExprType::Column(column) => Ok(Expr::Column(column.into())), ExprType::Literal(literal) => { - let scalar_value: ScalarValue = literal.try_into()?; + let scalar_value: Scalar = literal.try_into()?; Ok(Expr::Literal(scalar_value)) } ExprType::WindowExpr(expr) => { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 6f6065a1c284..fa35b19051bb 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -121,7 +121,7 @@ pub fn serialize_physical_window_expr( } else if let Some(ntile_expr) = built_in_fn_expr.downcast_ref::() { args.insert( 0, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( + Arc::new(Literal::from(datafusion_common::ScalarValue::Int64(Some( ntile_expr.get_n() as i64, )))), ); @@ -131,13 +131,13 @@ pub fn serialize_physical_window_expr( { args.insert( 1, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( + Arc::new(Literal::from(datafusion_common::ScalarValue::Int64(Some( window_shift_expr.get_shift_offset(), )))), ); args.insert( 2, - Arc::new(Literal::new(window_shift_expr.get_default_value())), + Arc::new(Literal::from(window_shift_expr.get_default_value())), ); if window_shift_expr.get_shift_offset() >= 0 { @@ -152,7 +152,7 @@ pub fn serialize_physical_window_expr( NthValueKind::Nth(n) => { args.insert( 1, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64( + Arc::new(Literal::from(datafusion_common::ScalarValue::Int64( Some(n), ))), ); @@ -352,7 +352,7 @@ pub fn serialize_physical_expr( } else if let Some(lit) = 
expr.downcast_ref::<Literal>() { Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( - lit.value().try_into()?, + lit.scalar().value().try_into()?, )), }) } else if let Some(cast) = expr.downcast_ref::<CastExpr>() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index cd789e06dc3b..fe7c03177887 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1968,7 +1968,7 @@ fn roundtrip_case_with_null() { let test_expr = Expr::Case(Case::new( Some(Box::new(lit(1.0_f32))), vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - Some(Box::new(Expr::Literal(ScalarValue::Null))), + Some(Box::new(Expr::from(ScalarValue::Null))), )); let ctx = SessionContext::new(); @@ -1977,7 +1977,7 @@ fn roundtrip_case_with_null() { #[test] fn roundtrip_null_literal() { - let test_expr = Expr::Literal(ScalarValue::Null); + let test_expr = Expr::from(ScalarValue::Null); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 025676f790a8..87f9270c7d4c 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -1047,7 +1047,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "result".to_string(), ))); let aggr_args: Vec<Arc<dyn PhysicalExpr>> = - vec![Arc::new(Literal::new(ScalarValue::from(42)))]; + vec![Arc::new(Literal::from(ScalarValue::from(42)))]; let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), aggr_args.clone()) .schema(Arc::clone(&schema)) diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index d1b50105d053..9af9653ccd60 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -256,7 +256,7 @@ fn test_expression_serialization_roundtrip() { use datafusion_proto::logical_plan::from_proto::parse_expr; let ctx = SessionContext::new(); - let lit = Expr::Literal(ScalarValue::Utf8(None)); + let lit = Expr::from(ScalarValue::Utf8(None)); for function in string::functions() { // default to 4 args (though some exprs like substr have error checking) let num_args = 4; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 34e119c45fdf..b1f19ce8c61e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -184,7 +184,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::Extract { field, expr, ..
} => { let mut extract_args = vec![ - Expr::Literal(ScalarValue::from(format!("{field}"))), + Expr::from(ScalarValue::from(format!("{field}"))), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ]; diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index f58ab5ff3612..f0fe7d98fbf5 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -51,7 +51,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { (None, Some(for_expr)) => { let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - let from_logic = Expr::Literal(ScalarValue::Int64(Some(1))); + let from_logic = Expr::from(ScalarValue::Int64(Some(1))); let for_logic = self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; vec![arg, from_logic, for_logic] diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index be0909b58468..062e20f2cabc 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -41,7 +41,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match value { Value::Number(n, _) => self.parse_sql_number(&n, false), Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(lit(s)), - Value::Null => Ok(Expr::Literal(ScalarValue::Null)), + Value::Null => Ok(Expr::from(ScalarValue::Null)), Value::Boolean(n) => Ok(lit(n)), Value::Placeholder(param) => { Self::create_placeholder_expr(param, param_data_types) @@ -351,7 +351,7 @@ fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { )))); } - Ok(Expr::Literal(ScalarValue::Decimal128( + Ok(Expr::from(ScalarValue::Decimal128( Some(if negative { -number } else { number }), precision as u8, scale as i8, diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 71328cfd018c..a75e1436deef 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -177,7 +177,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// tracks a more general solution fn get_constant_result(expr: &Expr, arg_name: &str) -> Result { match expr { - Expr::Literal(ScalarValue::Int64(Some(s))) => Ok(*s), + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Int64(Some(s)) => Ok(*s), + _ => plan_err!("Unexpected scalar value {scalar} in {arg_name} clause"), + }, Expr::BinaryExpr(binary_expr) => { let lhs = get_constant_result(&binary_expr.left, arg_name)?; let rhs = get_constant_result(&binary_expr.right, arg_name)?; diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 656d72d07ba2..365907976974 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1698,7 +1698,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .cloned() .unwrap_or_else(|| { // If there is no default for the column, then the default is NULL - Expr::Literal(ScalarValue::Null) + Expr::from(ScalarValue::Null) }) .cast_to(target_field.data_type(), &DFSchema::empty())?, }; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 537ac2274424..decf318f03d0 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -204,7 +204,7 @@ impl Unparser<'_> { }), } } - Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), + Expr::Literal(scalar) => Ok(self.scalar_to_sql(scalar.value())?), Expr::Alias(Alias { expr, name: _, .. 
}) => self.expr_to_sql_inner(expr), Expr::WindowFunction(WindowFunction { fun, @@ -1622,87 +1622,87 @@ mod tests { r#"a LIKE 'foo' ESCAPE 'o'"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(0))), + Expr::from(ScalarValue::Date64(Some(0))), r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(10000))), + Expr::from(ScalarValue::Date64(Some(10000))), r#"CAST('1970-01-01 00:00:10' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(-10000))), + Expr::from(ScalarValue::Date64(Some(-10000))), r#"CAST('1969-12-31 23:59:50' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(0))), + Expr::from(ScalarValue::Date32(Some(0))), r#"CAST('1970-01-01' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(10))), + Expr::from(ScalarValue::Date32(Some(10))), r#"CAST('1970-01-11' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(-1))), + Expr::from(ScalarValue::Date32(Some(-1))), r#"CAST('1969-12-31' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None)), + Expr::from(ScalarValue::TimestampSecond(Some(10001), None)), r#"CAST('1970-01-01 02:46:41' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond( + Expr::from(ScalarValue::TimestampSecond( Some(10001), Some("+08:00".into()), )), r#"CAST('1970-01-01 10:46:41 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None)), + Expr::from(ScalarValue::TimestampMillisecond(Some(10001), None)), r#"CAST('1970-01-01 00:00:10.001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond( + Expr::from(ScalarValue::TimestampMillisecond( Some(10001), Some("+08:00".into()), )), r#"CAST('1970-01-01 08:00:10.001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None)), + Expr::from(ScalarValue::TimestampMicrosecond(Some(10001), None)), r#"CAST('1970-01-01 00:00:00.010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond( + Expr::from(ScalarValue::TimestampMicrosecond( Some(10001), Some("+08:00".into()), )), r#"CAST('1970-01-01 08:00:00.010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None)), + Expr::from(ScalarValue::TimestampNanosecond(Some(10001), None)), r#"CAST('1970-01-01 00:00:00.000010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond( + Expr::from(ScalarValue::TimestampNanosecond( Some(10001), Some("+08:00".into()), )), r#"CAST('1970-01-01 08:00:00.000010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::Time32Second(Some(10001))), + Expr::from(ScalarValue::Time32Second(Some(10001))), r#"CAST('02:46:41' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time32Millisecond(Some(10001))), + Expr::from(ScalarValue::Time32Millisecond(Some(10001))), r#"CAST('00:00:10.001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Microsecond(Some(10001))), + Expr::from(ScalarValue::Time64Microsecond(Some(10001))), r#"CAST('00:00:00.010001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001))), + Expr::from(ScalarValue::Time64Nanosecond(Some(10001))), r#"CAST('00:00:00.000010001' AS TIME)"#, ), (sum(col("a")), r#"sum(a)"#), @@ -1837,7 +1837,7 @@ mod tests { (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), // See test_interval_scalar_to_expr for interval literals ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal128( + (col("a") + col("b")).gt(Expr::from(ScalarValue::Decimal128( Some(100123), 28, 3, @@ -1845,7 +1845,7 @@ mod 
tests { r#"((a + b) > 100.123)"#, ), ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256( + (col("a") + col("b")).gt(Expr::from(ScalarValue::Decimal256( Some(100123.into()), 28, 3, @@ -2121,13 +2121,10 @@ mod tests { #[test] fn test_float_scalar_to_expr() { let tests = [ - (Expr::Literal(ScalarValue::Float64(Some(3f64))), "3.0"), - (Expr::Literal(ScalarValue::Float64(Some(3.1f64))), "3.1"), - (Expr::Literal(ScalarValue::Float32(Some(-2f32))), "-2.0"), - ( - Expr::Literal(ScalarValue::Float32(Some(-2.989f32))), - "-2.989", - ), + (Expr::from(ScalarValue::Float64(Some(3f64))), "3.0"), + (Expr::from(ScalarValue::Float64(Some(3.1f64))), "3.1"), + (Expr::from(ScalarValue::Float32(Some(-2f32))), "-2.0"), + (Expr::from(ScalarValue::Float32(Some(-2.989f32))), "-2.989"), ]; for (value, expected) in tests { let dialect = CustomDialectBuilder::new().build(); @@ -2215,7 +2212,7 @@ mod tests { let expr = ScalarUDF::new_from_impl( datafusion_functions::datetime::date_part::DatePartFunc::new(), ) - .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]); + .call(vec![Expr::from(ScalarValue::new_utf8(unit)), col("x")]); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{}", ast); @@ -2323,7 +2320,7 @@ mod tests { fn test_cast_value_to_dict_expr() { let tests = [( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + expr: Box::new(Expr::from(ScalarValue::Utf8(Some( "variation".to_string(), )))), data_type: DataType::Dictionary( @@ -2366,7 +2363,7 @@ mod tests { expr: Box::new(col("a")), data_type: DataType::Float64, }), - Expr::Literal(ScalarValue::Int64(Some(2))), + Expr::from(ScalarValue::Int64(Some(2))), ], }); let ast = unparser.expr_to_sql(&expr)?; diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index e05df8ba77fc..71a2f270ffb1 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -23,7 +23,7 @@ use datafusion_common::{ Column, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window, + utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Scalar, Window, }; use sqlparser::ast; @@ -209,7 +209,11 @@ pub(crate) fn date_part_to_sql( match (style, date_part_args.len()) { (DateFieldExtractStyle::Extract, 2) => { let date_expr = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(Scalar { + value: ScalarValue::Utf8(Some(field)), + .. + }) = &date_part_args[0] + { let field = match field.to_lowercase().as_str() { "year" => ast::DateTimeField::Year, "month" => ast::DateTimeField::Month, @@ -230,7 +234,11 @@ pub(crate) fn date_part_to_sql( (DateFieldExtractStyle::Strftime, 2) => { let column = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(Scalar { + value: ScalarValue::Utf8(Some(field)), + .. 
+ }) = &date_part_args[0] + { let field = match field.to_lowercase().as_str() { "year" => "%Y", "month" => "%m", diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index d8ad964be213..212d4925f327 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -163,23 +163,26 @@ pub(crate) fn resolve_positions_to_exprs( expr: Expr, select_exprs: &[Expr], ) -> Result { - match expr { + match &expr { // sql_expr_to_logical_expr maps number to i64 // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887 - Expr::Literal(ScalarValue::Int64(Some(position))) - if position > 0_i64 && position <= select_exprs.len() as i64 => - { - let index = (position - 1) as usize; - let select_expr = &select_exprs[index]; - Ok(match select_expr { - Expr::Alias(Alias { expr, .. }) => *expr.clone(), - _ => select_expr.clone(), - }) + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Int64(Some(position)) + if *position > 0_i64 && *position <= select_exprs.len() as i64 => + { + let index = (position - 1) as usize; + let select_expr = &select_exprs[index]; + Ok(match select_expr { + Expr::Alias(Alias { expr, .. }) => *expr.clone(), + _ => select_expr.clone(), + }) + } + ScalarValue::Int64(Some(position)) => plan_err!( + "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}", + position, select_exprs.len() + ), + _ => Ok(expr), } - Expr::Literal(ScalarValue::Int64(Some(position))) => plan_err!( - "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}", - position, select_exprs.len() - ), _ => Ok(expr), } } diff --git a/datafusion/sqllogictest/test_files/tpch/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/q16.slt.part index 8058371764f2..3009762a72d3 100644 --- a/datafusion/sqllogictest/test_files/tpch/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q16.slt.part @@ -88,7 +88,7 @@ physical_plan 21)----------------------------------CoalesceBatchesExec: target_batch_size=8192 22)------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 23)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }]) +24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { scalar: Int32(49) }, Literal { scalar: Int32(14) }, Literal { scalar: Int32(23) }, Literal { scalar: Int32(45) }, Literal { scalar: Int32(19) }, Literal { scalar: Int32(3) }, Literal { scalar: Int32(36) }, Literal { scalar: Int32(9) }]) 25)------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 26)--------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], has_header=false 27)--------------------------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/q19.slt.part 
b/datafusion/sqllogictest/test_files/tpch/q19.slt.part index 70465ea065a1..73903da54a02 100644 --- a/datafusion/sqllogictest/test_files/tpch/q19.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q19.slt.part @@ -69,7 +69,7 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] +06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { scalar: Utf8("SM CASE") }, Literal { scalar: Utf8("SM BOX") }, Literal { scalar: Utf8("SM PACK") }, Literal { scalar: Utf8("SM PKG") }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { scalar: Utf8("MED BAG") }, Literal { scalar: Utf8("MED BOX") }, Literal { scalar: Utf8("MED PKG") }, Literal { scalar: Utf8("MED PACK") }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { scalar: Utf8("LG CASE") }, Literal { scalar: Utf8("LG BOX") }, Literal { scalar: Utf8("LG PACK") }, Literal { scalar: Utf8("LG PKG") }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 09)----------------CoalesceBatchesExec: target_batch_size=8192 @@ -78,7 +78,7 @@ physical_plan 12)------------CoalesceBatchesExec: target_batch_size=8192 13)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 14)----------------CoalesceBatchesExec: target_batch_size=8192 -15)------------------FilterExec: (p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, 
Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND p_size@2 <= 15) AND p_size@2 >= 1 +15)------------------FilterExec: (p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { scalar: Utf8("SM CASE") }, Literal { scalar: Utf8("SM BOX") }, Literal { scalar: Utf8("SM PACK") }, Literal { scalar: Utf8("SM PKG") }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { scalar: Utf8("MED BAG") }, Literal { scalar: Utf8("MED BOX") }, Literal { scalar: Utf8("MED PKG") }, Literal { scalar: Utf8("MED PACK") }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { scalar: Utf8("LG CASE") }, Literal { scalar: Utf8("LG BOX") }, Literal { scalar: Utf8("LG PACK") }, Literal { scalar: Utf8("LG PKG") }]) AND p_size@2 <= 15) AND p_size@2 >= 1 16)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 17)----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/q22.slt.part index d2168b0136ba..79d233d4adde 100644 --- a/datafusion/sqllogictest/test_files/tpch/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q22.slt.part @@ -90,7 +90,7 @@ physical_plan 14)--------------------------CoalesceBatchesExec: target_batch_size=8192 15)----------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 16)------------------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------------------FilterExec: Use substr(c_phone@1, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }]) +17)--------------------------------FilterExec: Use substr(c_phone@1, 1, 2) IN (SET) ([Literal { scalar: Utf8("13") }, Literal { scalar: Utf8("31") }, Literal { scalar: Utf8("23") }, Literal { scalar: Utf8("29") }, Literal { scalar: Utf8("30") }, Literal { scalar: Utf8("18") }, Literal { scalar: Utf8("17") }]) 18)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 19)------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], has_header=false 20)--------------------------CoalesceBatchesExec: target_batch_size=8192 @@ -100,7 +100,7 @@ physical_plan 24)----------------------CoalescePartitionsExec 25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] 26)--------------------------CoalesceBatchesExec: target_batch_size=8192 -27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND Use substr(c_phone@0, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }]), projection=[c_acctbal@1] +27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND Use substr(c_phone@0, 1, 2) IN (SET) ([Literal { scalar: Utf8("13") }, Literal { scalar: Utf8("31") }, Literal { scalar: Utf8("23") }, Literal { scalar: Utf8("29") }, Literal 
{ scalar: Utf8("30") }, Literal { scalar: Utf8("18") }, Literal { scalar: Utf8("17") }]), projection=[c_acctbal@1] 28)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 29)--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], has_header=false diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 030536f9f830..650b24d076bf 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -59,7 +59,8 @@ use datafusion::logical_expr::builder::project; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, - Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + Repartition, Scalar, Subquery, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; @@ -1213,7 +1214,7 @@ pub async fn from_substrait_agg_func( if let Ok(fun) = ctx.udaf(function_name) { // deal with situation that count(*) got no arguments let args = if fun.name() == "count" && args.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + vec![Expr::from(ScalarValue::Int64(Some(1)))] } else { args }; @@ -1870,7 +1871,7 @@ fn from_substrait_bound( pub(crate) fn from_substrait_literal_without_names( lit: &Literal, extensions: &Extensions, -) -> Result { +) -> Result { from_substrait_literal(lit, extensions, &vec![], &mut 0) } @@ -1879,7 +1880,7 @@ fn from_substrait_literal( extensions: &Extensions, dfs_names: &Vec, name_idx: &mut usize, -) -> Result { +) -> Result { let scalar_value = match &lit.literal_type { Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), Some(LiteralType::I8(n)) => match lit.type_variation_reference { @@ -2023,6 +2024,7 @@ fn from_substrait_literal( &mut element_name_idx, ) }) + .map(|el| el.map(|scalar| scalar.into_value())) .collect::>>()?; *name_idx = element_name_idx; if elements.is_empty() { @@ -2084,10 +2086,13 @@ fn from_substrait_literal( &mut entry_name_idx, )?; ScalarStructBuilder::new() - .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) .with_scalar( - Field::new("value", value_sv.data_type(), true), - value_sv, + Field::new("key", key_sv.data_type().clone(), false), + key_sv.into_value(), + ) + .with_scalar( + Field::new("value", value_sv.data_type().clone(), true), + value_sv.into_value(), ) .build() }) @@ -2147,7 +2152,10 @@ fn from_substrait_literal( let sv = from_substrait_literal(field, extensions, dfs_names, name_idx)?; // We assume everything to be nullable, since Arrow's strict about things matching // and it's hard to match otherwise. - builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); + builder = builder.with_scalar( + Field::new(name, sv.data_type().clone(), true), + sv.into_value(), + ); } builder.build()? 
} @@ -2279,7 +2287,7 @@ fn from_substrait_literal( _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), }; - Ok(scalar_value) + Ok(Scalar::from(scalar_value)) } fn from_substrait_null( @@ -2589,7 +2597,10 @@ impl BuiltinExprBuilder { .await?; match escape_char_expr { - Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { + Expr::Literal(Scalar { + value: ScalarValue::Utf8(escape_char_string), + .. + }) => { // Convert Option<String> to Option<char> escape_char_string.and_then(|s| s.chars().next()) } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 1165ce13d236..a66e18c1522a 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -21,7 +21,7 @@ use substrait::proto::expression_reference::ExprType; use arrow_buffer::ToByteSlice; use datafusion::arrow::datatypes::{Field, IntervalUnit}; use datafusion::logical_expr::{ - CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, + CrossJoin, Distinct, Like, Partitioning, Scalar, WindowFrameUnits, }; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -1668,7 +1668,7 @@ fn make_substrait_like_expr( let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extensions)?; let escape_char = to_substrait_literal_expr( - &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), + &ScalarValue::Utf8(escape_char.map(|c| c.to_string())).into(), extensions, )?; let arguments = vec![ @@ -1826,22 +1826,19 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { )) } -fn to_substrait_literal( - value: &ScalarValue, - extensions: &mut Extensions, -) -> Result<Literal> { - if value.is_null() { +fn to_substrait_literal(scalar: &Scalar, extensions: &mut Extensions) -> Result<Literal> { + if scalar.value().is_null() { return Ok(Literal { nullable: true, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, literal_type: Some(LiteralType::Null(to_substrait_type( - &value.data_type(), + scalar.data_type(), true, extensions, )?)), }); } - let (literal_type, type_variation_reference) = match value { + let (literal_type, type_variation_reference) = match scalar.value() { ScalarValue::Boolean(Some(b)) => { (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF) } @@ -2030,7 +2027,7 @@ fn to_substrait_literal( let keys = (0..m.keys().len()) .map(|i| { to_substrait_literal( - &ScalarValue::try_from_array(&m.keys(), i)?, + &Scalar::try_from_array(&m.keys(), i)?, extensions, ) }) @@ -2038,7 +2035,7 @@ fn to_substrait_literal( let values = (0..m.values().len()) .map(|i| { to_substrait_literal( - &ScalarValue::try_from_array(&m.values(), i)?, + &Scalar::try_from_array(&m.values(), i)?, extensions, ) }) @@ -2064,17 +2061,14 @@ fn to_substrait_literal( .columns() .iter() .map(|col| { - to_substrait_literal( - &ScalarValue::try_from_array(col, 0)?, - extensions, - ) + to_substrait_literal(&Scalar::try_from_array(col, 0)?, extensions) }) .collect::>>()?, }), DEFAULT_TYPE_VARIATION_REF, ), _ => ( - not_impl_err!("Unsupported literal: {value:?}")?, + not_impl_err!("Unsupported literal: {scalar:?}")?, DEFAULT_TYPE_VARIATION_REF, ), }; @@ -2095,10 +2089,7 @@ fn convert_array_to_literal_list( let values = (0..nested_array.len()) .map(|i| { - to_substrait_literal( - &ScalarValue::try_from_array(&nested_array, i)?, - extensions, - ) + to_substrait_literal(&Scalar::try_from_array(&nested_array, i)?, extensions) }) .collect::>>()?; @@
-2120,7 +2111,7 @@ fn convert_array_to_literal_list( } fn to_substrait_literal_expr( - value: &ScalarValue, + value: &Scalar, extensions: &mut Extensions, ) -> Result { let literal = to_substrait_literal(value, extensions)?; @@ -2357,6 +2348,7 @@ mod test { println!("Checking round trip of {scalar:?}"); let mut extensions = Extensions::default(); + let scalar = Scalar::from(scalar); let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; let roundtrip_scalar = from_substrait_literal_without_names(&substrait_literal, &extensions)?; @@ -2371,6 +2363,8 @@ mod test { let scalar = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::new( 17, 25, 1234567890, ))); + + let scalar = Scalar::from(scalar); let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; let roundtrip_scalar = from_substrait_literal_without_names(&substrait_literal, &extensions)?; @@ -2545,7 +2539,7 @@ mod test { let ctx = SessionContext::new(); // One expression, empty input schema - let expr = Expr::Literal(ScalarValue::Int32(Some(42))); + let expr = Expr::from(ScalarValue::Int32(Some(42))); let field = Field::new("out", DataType::Int32, false); let empty_schema = DFSchemaRef::new(DFSchema::empty()); let substrait =