Skip to content

Commit ea01e56

Browse files
milenkovicmalamb
andauthored
Add plugable handler for CREATE FUNCTION (#9333)
* Add plugable function factory * cover `DROP FUNCTION` as well ... ... partially, as `SessionState` does not expose unregister_udf at the moment. * update documentation * fix doc test * Address PR comments (code organization) * Address PR comments (factory interface) * fix test after rebase * `remove`'s gone from the trait ... ... `DROP FUNCTION` will look for function name in all available registries (udf, udaf, udwf). `remove` may be necessary if UDaF and UDwF do not get `simplify` method from #9304. * Rename FunctionDefinition and export it ... FunctionDefinition already exists, DefinitionStatement makes more sense. * Update datafusion/expr/src/logical_plan/ddl.rs Co-authored-by: Andrew Lamb <[email protected]> * Update datafusion/core/src/execution/context/mod.rs Co-authored-by: Andrew Lamb <[email protected]> * Update datafusion/core/tests/user_defined/user_defined_scalar_functions.rs Co-authored-by: Andrew Lamb <[email protected]> * Update datafusion/expr/src/logical_plan/ddl.rs Co-authored-by: Andrew Lamb <[email protected]> * resolve part of follow up comments * Qualified functions are not supported anymore * update docs and todos * fix clippy * address additional comments * Add sqllogicteset for CREATE/DROP function * Add coverage for DROP FUNCTION IF EXISTS * fix multiline error * revert dialect back to generic in test ... ... as `create function` gets support in latest sqlparser. * fmt --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 3aba67e commit ea01e56

File tree

8 files changed

+498
-20
lines changed

8 files changed

+498
-20
lines changed

datafusion/core/src/execution/context/mod.rs

+94-3
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,10 @@ use crate::datasource::{
7373
};
7474
use crate::error::{DataFusionError, Result};
7575
use crate::logical_expr::{
76-
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateMemoryTable,
77-
CreateView, DropCatalogSchema, DropTable, DropView, Explain, LogicalPlan,
78-
LogicalPlanBuilder, SetVariable, TableSource, TableType, UNNAMED_TABLE,
76+
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
77+
CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, DropView,
78+
Explain, LogicalPlan, LogicalPlanBuilder, SetVariable, TableSource, TableType,
79+
UNNAMED_TABLE,
7980
};
8081
use crate::optimizer::OptimizerRule;
8182
use datafusion_sql::{
@@ -489,6 +490,8 @@ impl SessionContext {
489490
DdlStatement::DropTable(cmd) => self.drop_table(cmd).await,
490491
DdlStatement::DropView(cmd) => self.drop_view(cmd).await,
491492
DdlStatement::DropCatalogSchema(cmd) => self.drop_schema(cmd).await,
493+
DdlStatement::CreateFunction(cmd) => self.create_function(cmd).await,
494+
DdlStatement::DropFunction(cmd) => self.drop_function(cmd).await,
492495
},
493496
// TODO what about the other statements (like TransactionStart and TransactionEnd)
494497
LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
@@ -794,6 +797,55 @@ impl SessionContext {
794797
Ok(false)
795798
}
796799

800+
async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
801+
let function = {
802+
let state = self.state.read().clone();
803+
let function_factory = &state.function_factory;
804+
805+
match function_factory {
806+
Some(f) => f.create(state.config(), stmt).await?,
807+
_ => Err(DataFusionError::Configuration(
808+
"Function factory has not been configured".into(),
809+
))?,
810+
}
811+
};
812+
813+
match function {
814+
RegisterFunction::Scalar(f) => {
815+
self.state.write().register_udf(f)?;
816+
}
817+
RegisterFunction::Aggregate(f) => {
818+
self.state.write().register_udaf(f)?;
819+
}
820+
RegisterFunction::Window(f) => {
821+
self.state.write().register_udwf(f)?;
822+
}
823+
RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
824+
};
825+
826+
self.return_empty_dataframe()
827+
}
828+
829+
async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
830+
// we don't know function type at this point
831+
// decision has been made to drop all functions
832+
let mut dropped = false;
833+
dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some();
834+
dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some();
835+
dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some();
836+
837+
// DROP FUNCTION IF EXISTS drops the specified function only if that
838+
// function exists and in this way, it avoids error. While the DROP FUNCTION
839+
// statement also performs the same function, it throws an
840+
// error if the function does not exist.
841+
842+
if !stmt.if_exists && !dropped {
843+
exec_err!("Function does not exist")
844+
} else {
845+
self.return_empty_dataframe()
846+
}
847+
}
848+
797849
/// Registers a variable provider within this context.
798850
pub fn register_variable(
799851
&self,
@@ -1261,7 +1313,30 @@ impl QueryPlanner for DefaultQueryPlanner {
12611313
.await
12621314
}
12631315
}
1316+
/// A pluggable interface to handle `CREATE FUNCTION` statements
1317+
/// and interact with [SessionState] to registers new udf, udaf or udwf.
1318+
1319+
#[async_trait]
1320+
pub trait FunctionFactory: Sync + Send {
1321+
/// Handles creation of user defined function specified in [CreateFunction] statement
1322+
async fn create(
1323+
&self,
1324+
state: &SessionConfig,
1325+
statement: CreateFunction,
1326+
) -> Result<RegisterFunction>;
1327+
}
12641328

1329+
/// Type of function to create
1330+
pub enum RegisterFunction {
1331+
/// Scalar user defined function
1332+
Scalar(Arc<ScalarUDF>),
1333+
/// Aggregate user defined function
1334+
Aggregate(Arc<AggregateUDF>),
1335+
/// Window user defined function
1336+
Window(Arc<WindowUDF>),
1337+
/// Table user defined function
1338+
Table(String, Arc<dyn TableFunctionImpl>),
1339+
}
12651340
/// Execution context for registering data sources and executing queries.
12661341
/// See [`SessionContext`] for a higher level API.
12671342
///
@@ -1306,6 +1381,12 @@ pub struct SessionState {
13061381
table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
13071382
/// Runtime environment
13081383
runtime_env: Arc<RuntimeEnv>,
1384+
1385+
/// [FunctionFactory] to support pluggable user defined function handler.
1386+
///
1387+
/// It will be invoked on `CREATE FUNCTION` statements.
1388+
/// thus, changing dialect o PostgreSql is required
1389+
function_factory: Option<Arc<dyn FunctionFactory>>,
13091390
}
13101391

13111392
impl Debug for SessionState {
@@ -1392,6 +1473,7 @@ impl SessionState {
13921473
execution_props: ExecutionProps::new(),
13931474
runtime_env: runtime,
13941475
table_factories,
1476+
function_factory: None,
13951477
};
13961478

13971479
// register built in functions
@@ -1568,6 +1650,15 @@ impl SessionState {
15681650
self
15691651
}
15701652

1653+
/// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements
1654+
pub fn with_function_factory(
1655+
mut self,
1656+
function_factory: Arc<dyn FunctionFactory>,
1657+
) -> Self {
1658+
self.function_factory = Some(function_factory);
1659+
self
1660+
}
1661+
15711662
/// Replace the extension [`SerializerRegistry`]
15721663
pub fn with_serializer_registry(
15731664
mut self,

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

+128-2
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
use arrow::compute::kernels::numeric::add;
1919
use arrow_array::{
20-
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
20+
Array, ArrayRef, ArrowNativeTypeOp, Float32Array, Float64Array, Int32Array,
21+
RecordBatch, UInt8Array,
2122
};
2223
use arrow_schema::DataType::Float64;
2324
use arrow_schema::{DataType, Field, Schema};
25+
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
2426
use datafusion::prelude::*;
2527
use datafusion::{execution::registry::FunctionRegistry, test_util};
2628
use datafusion_common::cast::as_float64_array;
@@ -31,10 +33,12 @@ use datafusion_common::{
3133
use datafusion_expr::simplify::ExprSimplifyResult;
3234
use datafusion_expr::simplify::SimplifyInfo;
3335
use datafusion_expr::{
34-
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
36+
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable,
3537
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
3638
};
39+
use parking_lot::Mutex;
3740

41+
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
3842
use rand::{thread_rng, Rng};
3943
use std::any::Any;
4044
use std::iter;
@@ -735,6 +739,128 @@ async fn verify_udf_return_type() -> Result<()> {
735739
Ok(())
736740
}
737741

742+
#[derive(Debug, Default)]
743+
struct MockFunctionFactory {
744+
pub captured_expr: Mutex<Option<Expr>>,
745+
}
746+
747+
#[async_trait::async_trait]
748+
impl FunctionFactory for MockFunctionFactory {
749+
#[doc = r" Crates and registers a function from [CreateFunction] statement"]
750+
#[must_use]
751+
#[allow(clippy::type_complexity, clippy::type_repetition_in_bounds)]
752+
async fn create(
753+
&self,
754+
_config: &SessionConfig,
755+
statement: CreateFunction,
756+
) -> datafusion::error::Result<RegisterFunction> {
757+
// In this example, we always create a function that adds its arguments
758+
// with the name specified in `CREATE FUNCTION`. In a real implementation
759+
// the body of the created UDF would also likely be a function of the contents
760+
// of the `CreateFunction`
761+
let mock_add = Arc::new(|args: &[datafusion_expr::ColumnarValue]| {
762+
let args = datafusion_expr::ColumnarValue::values_to_arrays(args)?;
763+
let base =
764+
datafusion_common::cast::as_float64_array(&args[0]).expect("cast failed");
765+
let exponent =
766+
datafusion_common::cast::as_float64_array(&args[1]).expect("cast failed");
767+
768+
let array = base
769+
.iter()
770+
.zip(exponent.iter())
771+
.map(|(base, exponent)| match (base, exponent) {
772+
(Some(base), Some(exponent)) => Some(base.add_wrapping(exponent)),
773+
_ => None,
774+
})
775+
.collect::<arrow_array::Float64Array>();
776+
Ok(datafusion_expr::ColumnarValue::from(
777+
Arc::new(array) as arrow_array::ArrayRef
778+
))
779+
});
780+
781+
let args = statement.args.unwrap();
782+
let mock_udf = create_udf(
783+
&statement.name,
784+
vec![args[0].data_type.clone(), args[1].data_type.clone()],
785+
Arc::new(statement.return_type.unwrap()),
786+
datafusion_expr::Volatility::Immutable,
787+
mock_add,
788+
);
789+
790+
// capture expression so we can verify
791+
// it has been parsed
792+
*self.captured_expr.lock() = statement.params.return_;
793+
794+
Ok(RegisterFunction::Scalar(Arc::new(mock_udf)))
795+
}
796+
}
797+
798+
#[tokio::test]
799+
async fn create_scalar_function_from_sql_statement() -> Result<()> {
800+
let function_factory = Arc::new(MockFunctionFactory::default());
801+
let runtime_config = RuntimeConfig::new();
802+
let runtime_environment = RuntimeEnv::new(runtime_config)?;
803+
804+
let session_config = SessionConfig::new();
805+
let state =
806+
SessionState::new_with_config_rt(session_config, Arc::new(runtime_environment))
807+
.with_function_factory(function_factory.clone());
808+
809+
let ctx = SessionContext::new_with_state(state);
810+
let options = SQLOptions::new().with_allow_ddl(false);
811+
812+
let sql = r#"
813+
CREATE FUNCTION better_add(DOUBLE, DOUBLE)
814+
RETURNS DOUBLE
815+
RETURN $1 + $2
816+
"#;
817+
818+
// try to `create function` when sql options have allow ddl disabled
819+
assert!(ctx.sql_with_options(sql, options).await.is_err());
820+
821+
// Create the `better_add` function dynamically via CREATE FUNCTION statement
822+
assert!(ctx.sql(sql).await.is_ok());
823+
// try to `drop function` when sql options have allow ddl disabled
824+
assert!(ctx
825+
.sql_with_options("drop function better_add", options)
826+
.await
827+
.is_err());
828+
829+
ctx.sql("select better_add(2.0, 2.0)").await?.show().await?;
830+
831+
// check if we sql expr has been converted to datafusion expr
832+
let captured_expression = function_factory.captured_expr.lock().clone().unwrap();
833+
assert_eq!("$1 + $2", captured_expression.to_string());
834+
835+
// statement drops function
836+
assert!(ctx.sql("drop function better_add").await.is_ok());
837+
// no function, it panics
838+
assert!(ctx.sql("drop function better_add").await.is_err());
839+
// no function, it dies not care
840+
assert!(ctx.sql("drop function if exists better_add").await.is_ok());
841+
// query should fail as there is no function
842+
assert!(ctx.sql("select better_add(2.0, 2.0)").await.is_err());
843+
844+
// tests expression parsing
845+
// if expression is not correct
846+
let bad_expression_sql = r#"
847+
CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE)
848+
RETURNS DOUBLE
849+
RETURN $1 $3
850+
"#;
851+
assert!(ctx.sql(bad_expression_sql).await.is_err());
852+
853+
// tests bad function definition
854+
let bad_definition_sql = r#"
855+
CREATE FUNCTION bad_definition_fun(DOUBLE, DOUBLE)
856+
RET BAD_TYPE
857+
RETURN $1 + $3
858+
"#;
859+
assert!(ctx.sql(bad_definition_sql).await.is_err());
860+
861+
Ok(())
862+
}
863+
738864
fn create_udf_context() -> SessionContext {
739865
let ctx = SessionContext::new();
740866
// register a custom UDF

0 commit comments

Comments
 (0)