Skip to content

Commit 2873fd0

Browse files
jayzhan211 and alamb authored
Add a ScalarUDFImpl::simplify() API, move SimplifyInfo et al to datafusion_expr (#9304)
* first draft Signed-off-by: jayzhan211 <[email protected]>
* clippy Signed-off-by: jayzhan211 <[email protected]>
* add comments Signed-off-by: jayzhan211 <[email protected]>
* move to optimize rule Signed-off-by: jayzhan211 <[email protected]>
* cleanup Signed-off-by: jayzhan211 <[email protected]>
* fix explain test Signed-off-by: jayzhan211 <[email protected]>
* move to simplifier Signed-off-by: jayzhan211 <[email protected]>
* pass with schema Signed-off-by: jayzhan211 <[email protected]>
* fix explain Signed-off-by: jayzhan211 <[email protected]>
* fix doc Signed-off-by: jayzhan211 <[email protected]>
* move to expr Signed-off-by: jayzhan211 <[email protected]>
* change simplify signature Signed-off-by: jayzhan211 <[email protected]>
* cleanup Signed-off-by: jayzhan211 <[email protected]>
* cleanup Signed-off-by: jayzhan211 <[email protected]>
* fix doc Signed-off-by: jayzhan211 <[email protected]>
* fix doc Signed-off-by: jayzhan211 <[email protected]>
* Update datafusion/expr/src/udf.rs
* Add backwards compatible uses, inline FunctionSimplifier, rename to ExprSimplifyResult
* Remove DFSchema from SimplifyInfo
* Avoid requiring argument copies
* Improve docs
* fix link
* fix doc test
* Update datafusion/physical-expr/src/lib.rs
* Change example simplify to always simplify its argument
* Clarify comment
--------- Signed-off-by: jayzhan211 <[email protected]> Co-authored-by: Andrew Lamb <[email protected]>
1 parent 3854419 commit 2873fd0

File tree

35 files changed

+287
-99
lines changed

35 files changed

+287
-99
lines changed

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-examples/examples/expr_api.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ use arrow::record_batch::RecordBatch;
2424
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
2525
use datafusion::common::{DFField, DFSchema};
2626
use datafusion::error::Result;
27-
use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
28-
use datafusion::physical_expr::execution_props::ExecutionProps;
27+
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
2928
use datafusion::physical_expr::{
3029
analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr,
3130
};
3231
use datafusion::prelude::*;
3332
use datafusion_common::{ScalarValue, ToDFSchema};
33+
use datafusion_expr::execution_props::ExecutionProps;
3434
use datafusion_expr::expr::BinaryExpr;
3535
use datafusion_expr::interval_arithmetic::Interval;
36+
use datafusion_expr::simplify::SimplifyContext;
3637
use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};
3738

3839
/// This example demonstrates the DataFusion [`Expr`] API.

datafusion-examples/examples/simple_udtf.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ use datafusion::physical_plan::memory::MemoryExec;
2828
use datafusion::physical_plan::ExecutionPlan;
2929
use datafusion::prelude::SessionContext;
3030
use datafusion_common::{plan_err, ScalarValue};
31+
use datafusion_expr::simplify::SimplifyContext;
3132
use datafusion_expr::{Expr, TableType};
32-
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
33+
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
3334
use std::fs::File;
3435
use std::io::Seek;
3536
use std::path::Path;

datafusion/core/src/datasource/listing/helpers.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,10 @@ use arrow::{
3333
use arrow_schema::Fields;
3434
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
3535
use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError};
36+
use datafusion_expr::execution_props::ExecutionProps;
3637
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
3738
use datafusion_physical_expr::create_physical_expr;
38-
use datafusion_physical_expr::execution_props::ExecutionProps;
39-
40-
use futures::stream::{BoxStream, FuturesUnordered};
41-
use futures::{StreamExt, TryStreamExt};
39+
use futures::stream::{BoxStream, FuturesUnordered, StreamExt, TryStreamExt};
4240
use log::{debug, trace};
4341
use object_store::path::Path;
4442
use object_store::{ObjectMeta, ObjectStore};

datafusion/core/src/datasource/physical_plan/parquet/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,13 +800,14 @@ mod tests {
800800
ArrayRef, Date64Array, Int32Array, Int64Array, Int8Array, StringArray,
801801
StructArray,
802802
};
803+
803804
use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder};
804805
use arrow::record_batch::RecordBatch;
805806
use arrow_schema::Fields;
806807
use datafusion_common::{assert_contains, FileType, GetExt, ScalarValue, ToDFSchema};
808+
use datafusion_expr::execution_props::ExecutionProps;
807809
use datafusion_expr::{col, lit, when, Expr};
808810
use datafusion_physical_expr::create_physical_expr;
809-
use datafusion_physical_expr::execution_props::ExecutionProps;
810811

811812
use chrono::{TimeZone, Utc};
812813
use futures::StreamExt;

datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,9 @@ mod test {
401401
use super::*;
402402
use arrow::datatypes::Field;
403403
use datafusion_common::ToDFSchema;
404+
use datafusion_expr::execution_props::ExecutionProps;
404405
use datafusion_expr::{cast, col, lit, Expr};
405406
use datafusion_physical_expr::create_physical_expr;
406-
use datafusion_physical_expr::execution_props::ExecutionProps;
407407
use parquet::arrow::parquet_to_arrow_schema;
408408
use parquet::file::reader::{FileReader, SerializedFileReader};
409409
use rand::prelude::*;

datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,8 @@ mod tests {
346346
use arrow::datatypes::Schema;
347347
use arrow::datatypes::{DataType, Field};
348348
use datafusion_common::{Result, ToDFSchema};
349+
use datafusion_expr::execution_props::ExecutionProps;
349350
use datafusion_expr::{cast, col, lit, Expr};
350-
use datafusion_physical_expr::execution_props::ExecutionProps;
351351
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
352352
use parquet::arrow::arrow_to_parquet_schema;
353353
use parquet::arrow::async_reader::ParquetObjectReader;

datafusion/core/src/execution/context/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ use datafusion_common::{
4343
tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor},
4444
};
4545
use datafusion_execution::registry::SerializerRegistry;
46+
pub use datafusion_expr::execution_props::ExecutionProps;
47+
use datafusion_expr::var_provider::is_system_variables;
4648
use datafusion_expr::{
4749
logical_plan::{DdlStatement, Statement},
4850
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
4951
};
50-
pub use datafusion_physical_expr::execution_props::ExecutionProps;
51-
use datafusion_physical_expr::var_provider::is_system_variables;
5252
use parking_lot::RwLock;
5353
use std::collections::hash_map::Entry;
5454
use std::string::String;

datafusion/core/src/physical_optimizer/pruning.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1341,10 +1341,10 @@ mod tests {
13411341
datatypes::{DataType, TimeUnit},
13421342
};
13431343
use datafusion_common::{ScalarValue, ToDFSchema};
1344+
use datafusion_expr::execution_props::ExecutionProps;
13441345
use datafusion_expr::expr::InList;
13451346
use datafusion_expr::{cast, is_null, try_cast, Expr};
13461347
use datafusion_physical_expr::create_physical_expr;
1347-
use datafusion_physical_expr::execution_props::ExecutionProps;
13481348
use std::collections::HashMap;
13491349
use std::ops::{Not, Rem};
13501350

datafusion/core/src/test_util/parquet.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ use crate::datasource::listing::{ListingTableUrl, PartitionedFile};
2828
use crate::datasource::object_store::ObjectStoreUrl;
2929
use crate::datasource::physical_plan::{FileScanConfig, ParquetExec};
3030
use crate::error::Result;
31-
use crate::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
31+
use crate::logical_expr::execution_props::ExecutionProps;
32+
use crate::logical_expr::simplify::SimplifyContext;
33+
use crate::optimizer::simplify_expressions::ExprSimplifier;
3234
use crate::physical_expr::create_physical_expr;
33-
use crate::physical_expr::execution_props::ExecutionProps;
3435
use crate::physical_plan::filter::FilterExec;
3536
use crate::physical_plan::metrics::MetricsSet;
3637
use crate::physical_plan::ExecutionPlan;

datafusion/core/src/variable/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717

1818
//! Variable provider for `@name` and `@@name` style runtime values.
1919
20-
pub use datafusion_physical_expr::var_provider::{VarProvider, VarType};
20+
pub use datafusion_expr::var_provider::{VarProvider, VarType};

datafusion/core/tests/dataframe/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ use datafusion_common::{assert_contains, DataFusionError, ScalarValue, UnnestOpt
4848
use datafusion_execution::config::SessionConfig;
4949
use datafusion_execution::runtime_env::RuntimeEnv;
5050
use datafusion_expr::expr::{GroupingSet, Sort};
51+
use datafusion_expr::var_provider::{VarProvider, VarType};
5152
use datafusion_expr::{
5253
array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col,
5354
placeholder, scalar_subquery, sum, when, wildcard, AggregateFunction, Expr,
5455
ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits,
5556
WindowFunctionDefinition,
5657
};
57-
use datafusion_physical_expr::var_provider::{VarProvider, VarType};
5858

5959
#[tokio::test]
6060
async fn test_count_wildcard_on_sort() -> Result<()> {

datafusion/core/tests/parquet/page_pruning.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ use datafusion::physical_plan::metrics::MetricValue;
2828
use datafusion::physical_plan::ExecutionPlan;
2929
use datafusion::prelude::SessionContext;
3030
use datafusion_common::{ScalarValue, Statistics, ToDFSchema};
31+
use datafusion_expr::execution_props::ExecutionProps;
3132
use datafusion_expr::{col, lit, Expr};
3233
use datafusion_physical_expr::create_physical_expr;
33-
use datafusion_physical_expr::execution_props::ExecutionProps;
3434

3535
use futures::StreamExt;
3636
use object_store::path::Path;

datafusion/core/tests/simplification.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,17 @@
2020
use arrow::datatypes::{DataType, Field, Schema};
2121
use arrow_array::{ArrayRef, Int32Array};
2222
use chrono::{DateTime, TimeZone, Utc};
23-
use datafusion::common::DFSchema;
2423
use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*};
2524
use datafusion_common::cast::as_int32_array;
2625
use datafusion_common::ScalarValue;
26+
use datafusion_common::{DFSchemaRef, ToDFSchema};
2727
use datafusion_expr::expr::ScalarFunction;
28+
use datafusion_expr::simplify::SimplifyInfo;
2829
use datafusion_expr::{
2930
expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable,
3031
LogicalPlan, LogicalPlanBuilder, ScalarUDF, Volatility,
3132
};
32-
use datafusion_optimizer::simplify_expressions::{
33-
ExprSimplifier, SimplifyExpressions, SimplifyInfo,
34-
};
33+
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
3534
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
3635
use std::sync::Arc;
3736

@@ -42,7 +41,7 @@ use std::sync::Arc;
4241
/// objects or from some other implementation
4342
struct MyInfo {
4443
/// The input schema
45-
schema: DFSchema,
44+
schema: DFSchemaRef,
4645

4746
/// Execution specific details needed for constant evaluation such
4847
/// as the current time for `now()` and [VariableProviders]
@@ -51,24 +50,27 @@ struct MyInfo {
5150

5251
impl SimplifyInfo for MyInfo {
5352
fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
54-
Ok(matches!(expr.get_type(&self.schema)?, DataType::Boolean))
53+
Ok(matches!(
54+
expr.get_type(self.schema.as_ref())?,
55+
DataType::Boolean
56+
))
5557
}
5658

5759
fn nullable(&self, expr: &Expr) -> Result<bool> {
58-
expr.nullable(&self.schema)
60+
expr.nullable(self.schema.as_ref())
5961
}
6062

6163
fn execution_props(&self) -> &ExecutionProps {
6264
&self.execution_props
6365
}
6466

6567
fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
66-
expr.get_type(&self.schema)
68+
expr.get_type(self.schema.as_ref())
6769
}
6870
}
6971

70-
impl From<DFSchema> for MyInfo {
71-
fn from(schema: DFSchema) -> Self {
72+
impl From<DFSchemaRef> for MyInfo {
73+
fn from(schema: DFSchemaRef) -> Self {
7274
Self {
7375
schema,
7476
execution_props: ExecutionProps::new(),
@@ -81,13 +83,13 @@ impl From<DFSchema> for MyInfo {
8183
/// a: Int32 (possibly with nulls)
8284
/// b: Int32
8385
/// s: Utf8
84-
fn schema() -> DFSchema {
86+
fn schema() -> DFSchemaRef {
8587
Schema::new(vec![
8688
Field::new("a", DataType::Int32, true),
8789
Field::new("b", DataType::Int32, false),
8890
Field::new("s", DataType::Utf8, false),
8991
])
90-
.try_into()
92+
.to_dfschema_ref()
9193
.unwrap()
9294
}
9395

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
use arrow::compute::kernels::numeric::add;
19-
use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array};
19+
use arrow_array::{
20+
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
21+
};
2022
use arrow_schema::DataType::Float64;
2123
use arrow_schema::{DataType, Field, Schema};
2224
use datafusion::prelude::*;
@@ -26,10 +28,13 @@ use datafusion_common::{
2628
assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err,
2729
plan_err, ExprSchema, Result, ScalarValue,
2830
};
31+
use datafusion_expr::simplify::ExprSimplifyResult;
32+
use datafusion_expr::simplify::SimplifyInfo;
2933
use datafusion_expr::{
3034
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
3135
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
3236
};
37+
3338
use rand::{thread_rng, Rng};
3439
use std::any::Any;
3540
use std::iter;
@@ -514,6 +519,101 @@ async fn deregister_udf() -> Result<()> {
514519
Ok(())
515520
}
516521

522+
#[derive(Debug)]
523+
struct CastToI64UDF {
524+
signature: Signature,
525+
}
526+
527+
impl CastToI64UDF {
528+
fn new() -> Self {
529+
Self {
530+
signature: Signature::any(1, Volatility::Immutable),
531+
}
532+
}
533+
}
534+
535+
impl ScalarUDFImpl for CastToI64UDF {
536+
fn as_any(&self) -> &dyn Any {
537+
self
538+
}
539+
fn name(&self) -> &str {
540+
"cast_to_i64"
541+
}
542+
fn signature(&self) -> &Signature {
543+
&self.signature
544+
}
545+
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
546+
Ok(DataType::Int64)
547+
}
548+
549+
// Demonstrate simplifying a UDF
550+
fn simplify(
551+
&self,
552+
mut args: Vec<Expr>,
553+
info: &dyn SimplifyInfo,
554+
) -> Result<ExprSimplifyResult> {
555+
// DataFusion should have ensured the function is called with just a
556+
// single argument
557+
assert_eq!(args.len(), 1);
558+
let arg = args.pop().unwrap();
559+
560+
// Note that Expr::cast_to requires an ExprSchema but simplify gets a
561+
// SimplifyInfo so we have to replicate some of the casting logic here.
562+
563+
let source_type = info.get_data_type(&arg)?;
564+
let new_expr = if source_type == DataType::Int64 {
565+
// the argument's data type is already the correct type
566+
arg
567+
} else {
568+
// need to use an actual cast to get the correct type
569+
Expr::Cast(datafusion_expr::Cast {
570+
expr: Box::new(arg),
571+
data_type: DataType::Int64,
572+
})
573+
};
574+
// return the newly written argument to DataFusion
575+
Ok(ExprSimplifyResult::Simplified(new_expr))
576+
}
577+
578+
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
579+
unimplemented!("Function should have been simplified prior to evaluation")
580+
}
581+
}
582+
583+
#[tokio::test]
584+
async fn test_user_defined_functions_cast_to_i64() -> Result<()> {
585+
let ctx = SessionContext::new();
586+
587+
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, false)]));
588+
589+
let batch = RecordBatch::try_new(
590+
schema,
591+
vec![Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]))],
592+
)?;
593+
594+
ctx.register_batch("t", batch)?;
595+
596+
let cast_to_i64_udf = ScalarUDF::from(CastToI64UDF::new());
597+
ctx.register_udf(cast_to_i64_udf);
598+
599+
let result = plan_and_collect(&ctx, "SELECT cast_to_i64(x) FROM t").await?;
600+
601+
assert_batches_eq!(
602+
&[
603+
"+------------------+",
604+
"| cast_to_i64(t.x) |",
605+
"+------------------+",
606+
"| 1 |",
607+
"| 2 |",
608+
"| 3 |",
609+
"+------------------+"
610+
],
611+
&result
612+
);
613+
614+
Ok(())
615+
}
616+
517617
#[derive(Debug)]
518618
struct TakeUDF {
519619
signature: Signature,

datafusion/expr/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ ahash = { version = "0.8", default-features = false, features = [
4040
] }
4141
arrow = { workspace = true }
4242
arrow-array = { workspace = true }
43+
chrono = { workspace = true }
4344
datafusion-common = { workspace = true, default-features = true }
4445
paste = "^1.0"
4546
sqlparser = { workspace = true }

datafusion/physical-expr/src/execution_props.rs renamed to datafusion/expr/src/execution_props.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,12 @@ use std::sync::Arc;
2424
/// Holds per-query execution properties and data (such as statement
2525
/// starting timestamps).
2626
///
27-
/// An [`ExecutionProps`] is created each time a [`LogicalPlan`] is
27+
/// An [`ExecutionProps`] is created each time a `LogicalPlan` is
2828
/// prepared for execution (optimized). If the same plan is optimized
2929
/// multiple times, a new `ExecutionProps` is created each time.
3030
///
3131
/// It is important that this structure be cheap to create as it is
3232
/// done so during predicate pruning and expression simplification
33-
///
34-
/// [`LogicalPlan`]: datafusion_expr::LogicalPlan
3533
#[derive(Clone, Debug)]
3634
pub struct ExecutionProps {
3735
pub query_execution_start_time: DateTime<Utc>,

0 commit comments

Comments
 (0)