Skip to content

Draft: User-defined type coercion in LogicalBuilder #15391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions datafusion-examples/examples/sql_frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use datafusion::common::{plan_err, TableReference};
use datafusion::config::ConfigOptions;
use datafusion::error::Result;
use datafusion::logical_expr::{
AggregateUDF, Expr, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, TableSource,
WindowUDF,
AggregateUDF, Expr, LogicalPlan, LogicalPlanBuilderConfig, ScalarUDF,
TableProviderFilterPushDown, TableSource, WindowUDF,
};
use datafusion::optimizer::{
Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule,
Expand Down Expand Up @@ -135,6 +135,7 @@ struct MyContextProvider {
options: ConfigOptions,
}

// Empty impl: this example provider overrides no items of `LogicalPlanBuilderConfig`.
impl LogicalPlanBuilderConfig for MyContextProvider {}
impl ContextProvider for MyContextProvider {
fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
if name.table() == "person" {
Expand Down
60 changes: 36 additions & 24 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ use datafusion_common::{
exec_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema,
DataFusionError, ParamValues, ScalarValue, SchemaError, UnnestOptions,
};
use datafusion_expr::user_defined_builder::UserDefinedLogicalBuilder;
use datafusion_expr::{
case,
dml::InsertOp,
Expand Down Expand Up @@ -347,9 +348,14 @@ impl DataFrame {
let plan = if window_func_exprs.is_empty() {
self.plan
} else {
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.window_plan(window_func_exprs)?
.build()?
};
let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;

let project_plan = UserDefinedLogicalBuilder::new(self.session_state.as_ref(), plan)
.project(expr_list)?
.build()?;

Ok(DataFrame {
session_state: self.session_state,
Expand Down Expand Up @@ -495,9 +501,10 @@ impl DataFrame {
/// # }
/// ```
pub fn filter(self, predicate: Expr) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
let plan = UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.filter(predicate)?
.build()?;

Ok(DataFrame {
session_state: self.session_state,
plan,
Expand Down Expand Up @@ -553,7 +560,7 @@ impl DataFrame {
let aggr_expr_len = aggr_expr.len();
let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
let plan = LogicalPlanBuilder::from(self.plan)
let plan = UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.with_options(options)
.aggregate(group_expr, aggr_expr)?
.build()?;
Expand All @@ -568,7 +575,9 @@ impl DataFrame {
.filter(|(idx, _)| *idx != grouping_id_pos)
.map(|(_, column)| Expr::Column(column))
.collect::<Vec<_>>();
LogicalPlanBuilder::from(plan).project(exprs)?.build()?
UserDefinedLogicalBuilder::new(self.session_state.as_ref(), plan)
.project(exprs)?
.build()?
} else {
plan
};
Expand All @@ -582,8 +591,8 @@ impl DataFrame {
/// Return a new DataFrame that adds the result of evaluating one or more
/// window functions ([`Expr::WindowFunction`]) to the existing columns
pub fn window(self, window_exprs: Vec<Expr>) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.window(window_exprs)?
let plan = UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.window_plan(window_exprs)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
Expand Down Expand Up @@ -621,7 +630,7 @@ impl DataFrame {
/// # }
/// ```
pub fn limit(self, skip: usize, fetch: Option<usize>) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
let plan = UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.limit(skip, fetch)?
.build()?;
Ok(DataFrame {
Expand Down Expand Up @@ -659,8 +668,8 @@ impl DataFrame {
/// # }
/// ```
pub fn union(self, dataframe: DataFrame) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.union(dataframe.plan)?
let plan = UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.union(vec![dataframe.plan])?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
Expand Down Expand Up @@ -732,7 +741,9 @@ impl DataFrame {
/// # }
/// ```
pub fn distinct(self) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan).distinct()?.build()?;
let plan = UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.distinct()?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
Expand Down Expand Up @@ -772,7 +783,7 @@ impl DataFrame {
select_expr: Vec<Expr>,
sort_expr: Option<Vec<SortExpr>>,
) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
let plan = UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.distinct_on(on_expr, select_expr, sort_expr)?
.build()?;
Ok(DataFrame {
Expand Down Expand Up @@ -1025,7 +1036,9 @@ impl DataFrame {
/// # }
/// ```
pub fn sort(self, expr: Vec<SortExpr>) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan).sort(expr)?.build()?;
let plan = UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.sort(expr, None)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
Expand Down Expand Up @@ -1086,14 +1099,10 @@ impl DataFrame {
right_cols: &[&str],
filter: Option<Expr>,
) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.join(
right.plan,
join_type,
(left_cols.to_vec(), right_cols.to_vec()),
filter,
)?
let plan = UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.join(right.plan, join_type, (left_cols.to_vec(), right_cols.to_vec()), filter)?
.build()?;

Ok(DataFrame {
session_state: self.session_state,
plan,
Expand Down Expand Up @@ -1768,7 +1777,9 @@ impl DataFrame {
} else {
(
Some(window_func_exprs[0].to_string()),
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?,
UserDefinedLogicalBuilder::new(self.session_state.as_ref(), self.plan)
.window_plan(window_func_exprs)?
.build()?,
)
};

Expand Down Expand Up @@ -1796,9 +1807,10 @@ impl DataFrame {
fields.push((new_column, true));
}

let project_plan = LogicalPlanBuilder::from(plan)
.project_with_validation(fields)?
.build()?;
let project_plan =
UserDefinedLogicalBuilder::new(self.session_state.as_ref(), plan)
.project_with_validation(fields)?
.build()?;

Ok(DataFrame {
session_state: self.session_state,
Expand Down
34 changes: 32 additions & 2 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::planner::{ExprPlanner, TypePlanner};
use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry};
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::type_coercion::TypeCoercion;
use datafusion_expr::var_provider::{is_system_variables, VarType};
use datafusion_expr::{
AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, TableSource,
WindowUDF,
AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilderConfig,
ScalarUDF, TableSource, WindowUDF,
};
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
use datafusion_optimizer::{
Expand Down Expand Up @@ -130,6 +131,8 @@ pub struct SessionState {
analyzer: Analyzer,
/// Provides support for customizing the SQL planner, e.g. to add support for custom operators like `->>` or `?`
expr_planners: Vec<Arc<dyn ExprPlanner>>,
/// Provides support for customizing the SQL type coercion
type_coercions: Vec<Arc<dyn TypeCoercion>>,
/// Provides support for customizing the SQL type planning
type_planner: Option<Arc<dyn TypePlanner>>,
/// Responsible for optimizing a logical plan
Expand Down Expand Up @@ -196,6 +199,7 @@ impl Debug for SessionState {
.field("table_factories", &self.table_factories)
.field("function_factory", &self.function_factory)
.field("expr_planners", &self.expr_planners)
.field("type_coercions", &self.type_coercions)
.field("type_planner", &self.type_planner)
.field("query_planners", &self.query_planner)
.field("analyzer", &self.analyzer)
Expand All @@ -210,6 +214,12 @@ impl Debug for SessionState {
}
}

/// Exposes the type coercion rules stored on a [`SessionState`] through the
/// [`LogicalPlanBuilderConfig`] trait.
impl LogicalPlanBuilderConfig for SessionState {
    /// Borrows the session's `type_coercions` field as a slice.
    fn get_type_coercions(&self) -> &[Arc<dyn TypeCoercion>] {
        self.type_coercions.as_slice()
    }
}

#[async_trait]
impl Session for SessionState {
fn session_id(&self) -> &str {
Expand Down Expand Up @@ -816,6 +826,11 @@ impl SessionState {
&self.serializer_registry
}

/// Return the type coercion rules held by this `SessionState`.
///
/// NOTE(review): the `LogicalPlanBuilderConfig` impl for `SessionState` exposes
/// this same field as `&[Arc<dyn TypeCoercion>]`; consider whether this getter
/// should return a slice as well. Confirm before the API is released.
pub fn type_coercions(&self) -> &Vec<Arc<dyn TypeCoercion>> {
&self.type_coercions
}

/// Return version of the cargo package that produced this query
pub fn version(&self) -> &str {
env!("CARGO_PKG_VERSION")
Expand Down Expand Up @@ -881,6 +896,7 @@ pub struct SessionStateBuilder {
session_id: Option<String>,
analyzer: Option<Analyzer>,
expr_planners: Option<Vec<Arc<dyn ExprPlanner>>>,
type_coercions: Option<Vec<Arc<dyn TypeCoercion>>>,
type_planner: Option<Arc<dyn TypePlanner>>,
optimizer: Option<Optimizer>,
physical_optimizers: Option<PhysicalOptimizer>,
Expand Down Expand Up @@ -917,6 +933,7 @@ impl SessionStateBuilder {
session_id: None,
analyzer: None,
expr_planners: None,
type_coercions: None,
type_planner: None,
optimizer: None,
physical_optimizers: None,
Expand Down Expand Up @@ -966,6 +983,7 @@ impl SessionStateBuilder {
session_id: None,
analyzer: Some(existing.analyzer),
expr_planners: Some(existing.expr_planners),
type_coercions: Some(existing.type_coercions),
type_planner: existing.type_planner,
optimizer: Some(existing.optimizer),
physical_optimizers: Some(existing.physical_optimizers),
Expand Down Expand Up @@ -1010,6 +1028,10 @@ impl SessionStateBuilder {
.get_or_insert_with(Vec::new)
.extend(SessionStateDefaults::default_expr_planners());

self.type_coercions
.get_or_insert_with(Vec::new)
.extend(SessionStateDefaults::default_type_coercions());

self.scalar_functions
.get_or_insert_with(Vec::new)
.extend(SessionStateDefaults::default_scalar_functions());
Expand Down Expand Up @@ -1318,6 +1340,7 @@ impl SessionStateBuilder {
session_id,
analyzer,
expr_planners,
type_coercions,
type_planner,
optimizer,
physical_optimizers,
Expand Down Expand Up @@ -1347,6 +1370,7 @@ impl SessionStateBuilder {
session_id: session_id.unwrap_or(Uuid::new_v4().to_string()),
analyzer: analyzer.unwrap_or_default(),
expr_planners: expr_planners.unwrap_or_default(),
type_coercions: type_coercions.unwrap_or_default(),
type_planner,
optimizer: optimizer.unwrap_or_default(),
physical_optimizers: physical_optimizers.unwrap_or_default(),
Expand Down Expand Up @@ -1622,6 +1646,12 @@ struct SessionContextProvider<'a> {
tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
}

/// Forwards the type coercion rules of the wrapped session state through the
/// [`LogicalPlanBuilderConfig`] trait.
impl LogicalPlanBuilderConfig for SessionContextProvider<'_> {
    /// Borrows `state.type_coercions` as a slice.
    fn get_type_coercions(&self) -> &[Arc<dyn TypeCoercion>] {
        self.state.type_coercions.as_slice()
    }
}

impl ContextProvider for SessionContextProvider<'_> {
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
&self.state.expr_planners
Expand Down
6 changes: 6 additions & 0 deletions datafusion/core/src/execution/session_state_defaults.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use datafusion_execution::config::SessionConfig;
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::type_coercion::{DefaultTypeCoercion, TypeCoercion};
use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -102,6 +103,11 @@ impl SessionStateDefaults {
expr_planners
}

/// Returns the default list of type coercion rules: a single
/// [`DefaultTypeCoercion`] instance.
pub fn default_type_coercions() -> Vec<Arc<dyn TypeCoercion>> {
    let default_coercion: Arc<dyn TypeCoercion> = Arc::new(DefaultTypeCoercion);
    vec![default_coercion]
}

/// returns the list of default [`ScalarUDF`]s
pub fn default_scalar_functions() -> Vec<Arc<ScalarUDF>> {
#[cfg_attr(not(feature = "nested_expressions"), allow(unused_mut))]
Expand Down
Loading