Skip to content

Introduce FunctionRegistry dependency to optimize and rewrite rule #10714

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 79 additions & 81 deletions datafusion-cli/Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2350,6 +2350,10 @@ impl OptimizerConfig for SessionState {
fn options(&self) -> &ConfigOptions {
self.config_options()
}

fn function_registry(&self) -> Option<&dyn FunctionRegistry> {
Some(self)
}
}

/// Create a new task context instance from SessionContext
Expand Down
7 changes: 6 additions & 1 deletion datafusion/execution/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@ pub mod config;
pub mod disk_manager;
pub mod memory_pool;
pub mod object_store;
pub mod registry;
pub mod runtime_env;
mod stream;
mod task;

pub mod registry {
pub use datafusion_expr::registry::{
FunctionRegistry, MemoryFunctionRegistry, SerializerRegistry,
};
}

pub use disk_manager::DiskManager;
pub use registry::FunctionRegistry;
pub use stream::{RecordBatchStream, SendableRecordBatchStream};
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub mod function;
pub mod groups_accumulator;
pub mod interval_arithmetic;
pub mod logical_plan;
pub mod registry;
pub mod simplify;
pub mod sort_properties;
pub mod tree_node;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

//! FunctionRegistry trait

use crate::expr_rewriter::FunctionRewrite;
use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use datafusion_common::{not_impl_err, plan_datafusion_err, Result};
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use std::collections::HashMap;
use std::{collections::HashSet, sync::Arc};

Expand Down
1 change: 0 additions & 1 deletion datafusion/optimizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ async-trait = { workspace = true }
chrono = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
datafusion-functions-aggregate = { workspace = true }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

datafusion-physical-expr = { workspace = true }
hashbrown = { workspace = true }
indexmap = { workspace = true }
Expand Down
5 changes: 5 additions & 0 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::collections::HashSet;
use std::sync::Arc;

use chrono::{DateTime, Utc};
use datafusion_expr::registry::FunctionRegistry;
use log::{debug, warn};

use datafusion_common::alias::AliasGenerator;
Expand Down Expand Up @@ -122,6 +123,10 @@ pub trait OptimizerConfig {
fn alias_generator(&self) -> Arc<AliasGenerator>;

fn options(&self) -> &ConfigOptions;

fn function_registry(&self) -> Option<&dyn FunctionRegistry> {
None
}
}

/// A standalone [`OptimizerConfig`] that can be used independently
Expand Down
69 changes: 14 additions & 55 deletions datafusion/optimizer/src/replace_distinct_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use crate::{OptimizerConfig, OptimizerRule};

use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Column, Result};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::expr_rewriter::normalize_cols;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{col, LogicalPlanBuilder};
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
use datafusion_functions_aggregate::first_last::first_value;

/// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]]
///
Expand Down Expand Up @@ -73,7 +73,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
match plan {
LogicalPlan::Distinct(Distinct::All(input)) => {
Expand All @@ -95,9 +95,18 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
let expr_cnt = on_expr.len();

// Construct the aggregation expression to be used to fetch the selected expressions.
let aggr_expr = select_expr
.into_iter()
.map(|e| first_value(vec![e], false, None, sort_expr.clone(), None));
let first_value_udaf =
config.function_registry().unwrap().udaf("first_value")?;
let aggr_expr = select_expr.into_iter().map(|e| {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change is the reason for this PR I think -- to avoid the hard coded dependency on first_value...

Expr::AggregateFunction(AggregateFunction::new_udf(
first_value_udaf.clone(),
vec![e],
false,
None,
sort_expr.clone(),
None,
))
});

let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
let group_expr = normalize_cols(on_expr, input.as_ref())?;
Expand Down Expand Up @@ -163,53 +172,3 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
Some(BottomUp)
}
}

#[cfg(test)]
mod tests {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to slt

use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
use crate::test::{assert_optimized_plan_eq, test_table_scan};
use datafusion_expr::{col, LogicalPlanBuilder};
use std::sync::Arc;

#[test]
fn replace_distinct() -> datafusion_common::Result<()> {
let table_scan = test_table_scan().unwrap();
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.distinct()?
.build()?;

let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\
\n Projection: test.a, test.b\
\n TableScan: test";

assert_optimized_plan_eq(
Arc::new(ReplaceDistinctWithAggregate::new()),
plan,
expected,
)
}

#[test]
fn replace_distinct_on() -> datafusion_common::Result<()> {
let table_scan = test_table_scan().unwrap();
let plan = LogicalPlanBuilder::from(table_scan)
.distinct_on(
vec![col("a")],
vec![col("b")],
Some(vec![col("a").sort(false, true), col("c").sort(true, false)]),
)?
.build()?;

let expected = "Projection: first_value(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST] AS b\
\n Sort: test.a DESC NULLS FIRST\
\n Aggregate: groupBy=[[test.a]], aggr=[[first_value(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST]]]\
\n TableScan: test";

assert_optimized_plan_eq(
Arc::new(ReplaceDistinctWithAggregate::new()),
plan,
expected,
)
}
}
36 changes: 36 additions & 0 deletions datafusion/sqllogictest/test_files/distinct_on.slt
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,39 @@ LIMIT 3;
-25 15295
45 15673
-72 -11122

# test distinct on
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

statement ok
create table t(a int, b int, c int) as values (1, 2, 3);

statement ok
set datafusion.explain.logical_plan_only = true;

query TT
explain select distinct on (a) b from t order by a desc, c;
----
logical_plan
01)Projection: first_value(t.b) ORDER BY [t.a DESC NULLS FIRST, t.c ASC NULLS LAST] AS b
02)--Sort: t.a DESC NULLS FIRST
03)----Aggregate: groupBy=[[t.a]], aggr=[[first_value(t.b) ORDER BY [t.a DESC NULLS FIRST, t.c ASC NULLS LAST]]]
04)------TableScan: t projection=[a, b, c]

statement ok
drop table t;

# test distinct
statement ok
create table t(a int, b int) as values (1, 2);

statement ok
set datafusion.explain.logical_plan_only = true;

query TT
explain select distinct a, b from t;
----
logical_plan
01)Aggregate: groupBy=[[t.a, t.b]], aggr=[[]]
02)--TableScan: t projection=[a, b]

statement ok
drop table t;