Skip to content

Commit 3242780

Browse files
committed
first draft
Signed-off-by: jayzhan211 <[email protected]>
1 parent 453a45a commit 3242780

File tree

5 files changed

+151
-29
lines changed

5 files changed

+151
-29
lines changed

datafusion/core/tests/user_defined/user_defined_aggregates.rs

+14
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,20 @@ async fn simple_udaf() -> Result<()> {
255255
Ok(())
256256
}
257257

258+
#[tokio::test]
259+
async fn test_struct_join() {
260+
let TestContext { ctx, test_state: _ } = TestContext::new();
261+
let sql = "SElECT l.result as result FROM (SELECT first(value, time) as result from t) as l JOIN (SELECT first(value, time) as result from t) as r ON l.result = r.result";
262+
let expected = [
263+
"+------------------------------------------------+",
264+
"| result |",
265+
"+------------------------------------------------+",
266+
"| {value: 2.0, time: 1970-01-01T00:00:00.000002} |",
267+
"+------------------------------------------------+",
268+
];
269+
assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
270+
}
271+
258272
#[tokio::test]
259273
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
260274
let ctx = SessionContext::new();

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

+83-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
// under the License.
1717

1818
use arrow::compute::kernels::numeric::add;
19-
use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array};
19+
use arrow_array::{
20+
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch,
21+
UInt8Array,
22+
};
2023
use arrow_schema::DataType::Float64;
2124
use arrow_schema::{DataType, Field, Schema};
2225
use datafusion::prelude::*;
@@ -26,12 +29,15 @@ use datafusion_common::{
2629
assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err,
2730
plan_err, DataFusionError, ExprSchema, Result, ScalarValue,
2831
};
32+
use datafusion_common::{DFField, DFSchema};
2933
use datafusion_expr::{
3034
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
31-
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
35+
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Simplified, Volatility,
3236
};
37+
3338
use rand::{thread_rng, Rng};
3439
use std::any::Any;
40+
use std::collections::HashMap;
3541
use std::iter;
3642
use std::sync::Arc;
3743

@@ -498,6 +504,81 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {
498504
Ok(())
499505
}
500506

507+
#[derive(Debug)]
508+
struct CastToI64UDF {
509+
signature: Signature,
510+
}
511+
512+
impl CastToI64UDF {
513+
fn new() -> Self {
514+
Self {
515+
signature: Signature::any(1, Volatility::Immutable),
516+
}
517+
}
518+
}
519+
520+
impl ScalarUDFImpl for CastToI64UDF {
521+
fn as_any(&self) -> &dyn Any {
522+
self
523+
}
524+
fn name(&self) -> &str {
525+
"cast_to_i64"
526+
}
527+
fn signature(&self) -> &Signature {
528+
&self.signature
529+
}
530+
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
531+
Ok(DataType::Int64)
532+
}
533+
// Wrap with Expr::Cast() to Int64
534+
fn simplify(&self, args: Vec<Expr>) -> Result<Simplified> {
535+
let dfs = DFSchema::new_with_metadata(
536+
vec![DFField::new(Some("t"), "x", DataType::Float32, true)],
537+
HashMap::default(),
538+
)?;
539+
let e = args[0].clone();
540+
let casted_expr = e.cast_to(&DataType::Int64, &dfs)?;
541+
Ok(Simplified::Rewritten(casted_expr))
542+
}
543+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
544+
Ok(args.get(0).unwrap().clone())
545+
}
546+
}
547+
548+
#[tokio::test]
549+
async fn test_user_defined_functions_cast_to_i64() -> Result<()> {
550+
let ctx = SessionContext::new();
551+
552+
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, false)]));
553+
554+
let batch = RecordBatch::try_new(
555+
schema,
556+
vec![Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]))],
557+
)?;
558+
559+
ctx.register_batch("t", batch)?;
560+
561+
let cast_to_i64_udf = ScalarUDF::from(CastToI64UDF::new());
562+
ctx.register_udf(cast_to_i64_udf);
563+
564+
let result = plan_and_collect(&ctx, "SELECT cast_to_i64(x) FROM t").await?;
565+
566+
assert_batches_eq!(
567+
&[
568+
"+------------------+",
569+
"| cast_to_i64(t.x) |",
570+
"+------------------+",
571+
"| 1 |",
572+
"| 2 |",
573+
"| 3 |",
574+
"+------------------+"
575+
],
576+
&result
577+
);
578+
579+
Ok(())
580+
}
581+
501582
#[derive(Debug)]
502583
struct TakeUDF {
503584
signature: Signature,

datafusion/expr/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ pub use signature::{
8080
};
8181
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
8282
pub use udaf::{AggregateUDF, AggregateUDFImpl};
83-
pub use udf::{ScalarUDF, ScalarUDFImpl};
83+
pub use udf::{ScalarUDF, ScalarUDFImpl, Simplified};
8484
pub use udwf::{WindowUDF, WindowUDFImpl};
8585
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
8686

datafusion/expr/src/udf.rs

+17
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ use std::fmt::Debug;
3030
use std::fmt::Formatter;
3131
use std::sync::Arc;
3232

33+
/// Was the expression simplified?
34+
pub enum Simplified {
35+
/// The function call was simplified to an entirely new Expr
36+
Rewritten(Expr),
37+
/// the function call could not be simplified, and the arguments
38+
/// are return unmodified
39+
Original(Vec<Expr>),
40+
}
41+
3342
/// Logical representation of a Scalar User Defined Function.
3443
///
3544
/// A scalar function produces a single row output for each row of input. This
@@ -160,6 +169,10 @@ impl ScalarUDF {
160169
self.inner.return_type_from_exprs(args, schema)
161170
}
162171

172+
pub fn simplify(&self, args: Vec<Expr>) -> Result<Simplified> {
173+
self.inner.simplify(args)
174+
}
175+
163176
/// Invoke the function on `args`, returning the appropriate result.
164177
///
165178
/// See [`ScalarUDFImpl::invoke`] for more details.
@@ -337,6 +350,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
337350
fn monotonicity(&self) -> Result<Option<FuncMonotonicity>> {
338351
Ok(None)
339352
}
353+
354+
fn simplify(&self, args: Vec<Expr>) -> Result<Simplified> {
355+
Ok(Simplified::Original(args))
356+
}
340357
}
341358

342359
/// ScalarUDF that adds an alias to the underlying function. It is better to

datafusion/physical-expr/src/planner.rs

+36-26
Original file line numberDiff line numberDiff line change
@@ -258,34 +258,44 @@ pub fn create_physical_expr(
258258
)))
259259
}
260260

261-
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
262-
let physical_args = args
263-
.iter()
264-
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
265-
.collect::<Result<Vec<_>>>()?;
266-
match func_def {
267-
ScalarFunctionDefinition::BuiltIn(fun) => {
268-
functions::create_physical_expr(
269-
fun,
270-
&physical_args,
271-
input_schema,
272-
execution_props,
273-
)
274-
}
275-
ScalarFunctionDefinition::UDF(fun) => {
276-
let return_type = fun.return_type_from_exprs(args, input_dfschema)?;
261+
Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
262+
ScalarFunctionDefinition::BuiltIn(fun) => {
263+
let physical_args = args
264+
.iter()
265+
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
266+
.collect::<Result<Vec<_>>>()?;
277267

278-
udf::create_physical_expr(
279-
fun.clone().as_ref(),
280-
&physical_args,
281-
return_type,
282-
)
283-
}
284-
ScalarFunctionDefinition::Name(_) => {
285-
internal_err!("Function `Expr` with name should be resolved.")
286-
}
268+
functions::create_physical_expr(
269+
fun,
270+
&physical_args,
271+
input_schema,
272+
execution_props,
273+
)
287274
}
288-
}
275+
ScalarFunctionDefinition::UDF(fun) => {
276+
let args = match fun.simplify(args.to_owned())? {
277+
datafusion_expr::Simplified::Original(args) => args,
278+
datafusion_expr::Simplified::Rewritten(expr) => vec![expr],
279+
};
280+
281+
let physical_args = args
282+
.iter()
283+
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
284+
.collect::<Result<Vec<_>>>()?;
285+
286+
let return_type =
287+
fun.return_type_from_exprs(args.as_slice(), input_dfschema)?;
288+
289+
udf::create_physical_expr(
290+
fun.clone().as_ref(),
291+
&physical_args,
292+
return_type,
293+
)
294+
}
295+
ScalarFunctionDefinition::Name(_) => {
296+
internal_err!("Function `Expr` with name should be resolved.")
297+
}
298+
},
289299
Expr::Between(Between {
290300
expr,
291301
negated,

0 commit comments

Comments
 (0)