Skip to content

Commit cade25d

Browse files
committed
FunctionFactory usage example
1 parent ea01e56 commit cade25d

File tree

1 file changed

+243
-0
lines changed

1 file changed

+243
-0
lines changed
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{ArrayRef, Int64Array, RecordBatch};
19+
use datafusion::error::Result;
20+
use datafusion::execution::config::SessionConfig;
21+
use datafusion::execution::context::{
22+
FunctionFactory, RegisterFunction, SessionContext, SessionState,
23+
};
24+
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
25+
use datafusion_common::tree_node::{Transformed, TreeNode};
26+
use datafusion_common::{exec_err, internal_err, DataFusionError};
27+
use datafusion_expr::simplify::ExprSimplifyResult;
28+
use datafusion_expr::simplify::SimplifyInfo;
29+
use datafusion_expr::{CreateFunction, Expr, ScalarUDF, ScalarUDFImpl, Signature};
30+
use std::result::Result as RResult;
31+
use std::sync::Arc;
32+
33+
/// This example shows how to utilize [FunctionFactory] to register
34+
/// `CREATE FUNCTION` handler. Apart from [FunctionFactory] this
35+
/// example covers [ScalarUDFImpl::simplify()] usage and synergy
36+
/// between those two functionality.
37+
///
38+
/// This example is rather simple, there are many edge cases to be covered
39+
///
40+
41+
#[tokio::main]
42+
async fn main() -> Result<()> {
43+
let runtime_config = RuntimeConfig::new();
44+
let runtime_environment = RuntimeEnv::new(runtime_config)?;
45+
46+
let session_config = SessionConfig::new();
47+
let state =
48+
SessionState::new_with_config_rt(session_config, Arc::new(runtime_environment))
49+
// register custom function factory
50+
.with_function_factory(Arc::new(CustomFunctionFactory::default()));
51+
52+
let ctx = SessionContext::new_with_state(state);
53+
54+
let sql = r#"
55+
CREATE FUNCTION f1(BIGINT)
56+
RETURNS BIGINT
57+
RETURN $1 + 1
58+
"#;
59+
60+
ctx.sql(sql).await?.show().await?;
61+
62+
let sql = r#"
63+
CREATE FUNCTION f2(BIGINT, BIGINT)
64+
RETURNS BIGINT
65+
RETURN $1 + f1($2)
66+
"#;
67+
68+
ctx.sql(sql).await?.show().await?;
69+
70+
let sql = r#"
71+
SELECT f1(1)
72+
"#;
73+
74+
ctx.sql(sql).await?.show().await?;
75+
76+
let sql = r#"
77+
SELECT f2(1, 2)
78+
"#;
79+
80+
ctx.sql(sql).await?.show().await?;
81+
82+
let a: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3, 4]));
83+
let b: ArrayRef = Arc::new(Int64Array::from(vec![10, 20, 30, 40]));
84+
let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;
85+
86+
ctx.register_batch("t", batch)?;
87+
88+
let sql = r#"
89+
SELECT f2(a, b) from t
90+
"#;
91+
92+
ctx.sql(sql).await?.show().await?;
93+
94+
ctx.sql("DROP FUNCTION f1").await?.show().await?;
95+
96+
ctx.sql("DROP FUNCTION f2").await?.show().await?;
97+
98+
Ok(())
99+
}
100+
101+
#[derive(Debug, Default)]
102+
struct CustomFunctionFactory {}
103+
104+
#[async_trait::async_trait]
105+
impl FunctionFactory for CustomFunctionFactory {
106+
async fn create(
107+
&self,
108+
_state: &SessionConfig,
109+
statement: CreateFunction,
110+
) -> Result<RegisterFunction> {
111+
let f: ScalarFunctionWrapper = statement.try_into()?;
112+
113+
Ok(RegisterFunction::Scalar(Arc::new(ScalarUDF::from(f))))
114+
}
115+
}
116+
// a wrapper type to be used to register
117+
// custom function to datafusion context
118+
//
119+
// it also defines custom [ScalarUDFImpl::simplify()]
120+
// to replace ScalarUDF expression with one instance contains.
121+
#[derive(Debug)]
122+
struct ScalarFunctionWrapper {
123+
name: String,
124+
expr: Expr,
125+
signature: Signature,
126+
return_type: arrow_schema::DataType,
127+
}
128+
129+
impl ScalarUDFImpl for ScalarFunctionWrapper {
130+
fn as_any(&self) -> &dyn std::any::Any {
131+
self
132+
}
133+
134+
fn name(&self) -> &str {
135+
&self.name
136+
}
137+
138+
fn signature(&self) -> &datafusion_expr::Signature {
139+
&self.signature
140+
}
141+
142+
fn return_type(
143+
&self,
144+
_arg_types: &[arrow_schema::DataType],
145+
) -> Result<arrow_schema::DataType> {
146+
Ok(self.return_type.clone())
147+
}
148+
149+
fn invoke(
150+
&self,
151+
_args: &[datafusion_expr::ColumnarValue],
152+
) -> Result<datafusion_expr::ColumnarValue> {
153+
internal_err!("This function should not get invoked!")
154+
}
155+
156+
fn simplify(
157+
&self,
158+
args: Vec<Expr>,
159+
_info: &dyn SimplifyInfo,
160+
) -> Result<ExprSimplifyResult> {
161+
let replacement = Self::replacement(&self.expr, &args)?;
162+
163+
Ok(ExprSimplifyResult::Simplified(replacement))
164+
}
165+
166+
fn aliases(&self) -> &[String] {
167+
&[]
168+
}
169+
170+
fn monotonicity(&self) -> Result<Option<datafusion_expr::FuncMonotonicity>> {
171+
Ok(None)
172+
}
173+
}
174+
175+
impl ScalarFunctionWrapper {
176+
// replaces placeholders with actual arguments
177+
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
178+
let result = expr.clone().transform(&|e| {
179+
let r = match e {
180+
Expr::Placeholder(placeholder) => {
181+
let placeholder_position =
182+
Self::parse_placeholder_identifier(&placeholder.id)?;
183+
if placeholder_position < args.len() {
184+
Transformed::yes(args[placeholder_position].clone())
185+
} else {
186+
exec_err!(
187+
"Function argument {} not provided, argument missing!",
188+
placeholder.id
189+
)?
190+
}
191+
}
192+
_ => Transformed::no(e),
193+
};
194+
195+
Ok(r)
196+
})?;
197+
198+
Ok(result.data)
199+
}
200+
// Finds placeholder identifier.
201+
// placeholders are in `$X` format where X >= 1
202+
fn parse_placeholder_identifier(placeholder: &str) -> Result<usize> {
203+
if let Some(value) = placeholder.strip_prefix('$') {
204+
Ok(value.parse().map(|v: usize| v - 1).map_err(|e| {
205+
DataFusionError::Execution(format!(
206+
"Placeholder `{}` parsing error: {}!",
207+
placeholder, e
208+
))
209+
})?)
210+
} else {
211+
exec_err!("Placeholder should start with `$`!")
212+
}
213+
}
214+
}
215+
216+
impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
217+
type Error = DataFusionError;
218+
219+
fn try_from(definition: CreateFunction) -> RResult<Self, Self::Error> {
220+
Ok(Self {
221+
name: definition.name,
222+
expr: definition
223+
.params
224+
.return_
225+
.expect("Expression has to be defined!"),
226+
return_type: definition
227+
.return_type
228+
.expect("Return type has to be defined!"),
229+
signature: Signature::exact(
230+
definition
231+
.args
232+
.unwrap_or_default()
233+
.into_iter()
234+
.map(|a| a.data_type)
235+
.collect(),
236+
definition
237+
.params
238+
.behavior
239+
.unwrap_or(datafusion_expr::Volatility::Volatile),
240+
),
241+
})
242+
}
243+
}

0 commit comments

Comments
 (0)