From 876959d7f4686ef0e73cbae38fec8ce1f01c3599 Mon Sep 17 00:00:00 2001 From: wiedld Date: Tue, 21 Jan 2025 22:50:55 -0800 Subject: [PATCH 1/4] feat(13525): permit user-defined invariants on logical plan extensions --- .../tests/user_defined/user_defined_plan.rs | 106 +++++++++++++++++- datafusion/expr/src/logical_plan/extension.rs | 31 +++++ .../expr/src/logical_plan/invariants.rs | 57 +++++++++- datafusion/expr/src/logical_plan/mod.rs | 4 +- 4 files changed, 193 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 487063642345..4ba9f21b641c 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -59,6 +59,7 @@ //! use std::fmt::Debug; +use std::hash::Hash; use std::task::{Context, Poll}; use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; @@ -93,7 +94,7 @@ use datafusion::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; -use datafusion_expr::{FetchType, Projection, SortExpr}; +use datafusion_expr::{FetchType, Invariant, InvariantLevel, Projection, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -295,7 +296,63 @@ async fn topk_plan() -> Result<()> { Ok(()) } +#[tokio::test] +/// Run invariant checks on the logical plan extension [`TopKPlanNode`]. 
+async fn topk_invariants() -> Result<()> { + // Test: pass an InvariantLevel::Always + let pass = InvariantMock { + should_fail_invariant: false, + kind: InvariantLevel::Always, + }; + let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?; + run_and_compare_query(ctx, "Topk context").await?; + + // Test: fail an InvariantLevel::Always + let fail = InvariantMock { + should_fail_invariant: true, + kind: InvariantLevel::Always, + }; + let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?; + matches!( + &*run_and_compare_query(ctx, "Topk context") + .await + .unwrap_err() + .message(), + "node fails check, such as improper inputs" + ); + + // Test: pass an InvariantLevel::Executable + let pass = InvariantMock { + should_fail_invariant: false, + kind: InvariantLevel::Executable, + }; + let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?; + run_and_compare_query(ctx, "Topk context").await?; + + // Test: fail an InvariantLevel::Executable + let fail = InvariantMock { + should_fail_invariant: true, + kind: InvariantLevel::Executable, + }; + let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?; + matches!( + &*run_and_compare_query(ctx, "Topk context") + .await + .unwrap_err() + .message(), + "node fails check, such as improper inputs" + ); + + Ok(()) +} + fn make_topk_context() -> SessionContext { + make_topk_context_with_invariants(None) +} + +fn make_topk_context_with_invariants( + invariant_mock: Option, +) -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); let state = SessionStateBuilder::new() @@ -303,7 +360,7 @@ fn make_topk_context() -> SessionContext { .with_runtime_env(runtime) .with_default_features() .with_query_planner(Arc::new(TopKQueryPlanner {})) - .with_optimizer_rule(Arc::new(TopKOptimizerRule {})) + .with_optimizer_rule(Arc::new(TopKOptimizerRule { invariant_mock })) 
.with_analyzer_rule(Arc::new(MyAnalyzerRule {})) .build(); SessionContext::new_with_state(state) @@ -336,7 +393,10 @@ impl QueryPlanner for TopKQueryPlanner { } #[derive(Default, Debug)] -struct TopKOptimizerRule {} +struct TopKOptimizerRule { + /// A testing-only hashable fixture. + invariant_mock: Option, +} impl OptimizerRule for TopKOptimizerRule { fn name(&self) -> &str { @@ -380,6 +440,7 @@ impl OptimizerRule for TopKOptimizerRule { k: fetch, input: input.as_ref().clone(), expr: expr[0].clone(), + invariant_mock: self.invariant_mock.clone(), }), }))); } @@ -396,6 +457,10 @@ struct TopKPlanNode { /// The sort expression (this example only supports a single sort /// expr) expr: SortExpr, + + /// A testing-only hashable fixture. + /// For actual use, define the [`Invariant`] in the [`UserDefinedLogicalNodeCore::invariants`]. + invariant_mock: Option, } impl Debug for TopKPlanNode { @@ -406,6 +471,20 @@ impl Debug for TopKPlanNode { } } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +struct InvariantMock { + should_fail_invariant: bool, + kind: InvariantLevel, +} + +fn invariant_helper_mock_ok(_: &LogicalPlan) -> Result<()> { + Ok(()) +} + +fn invariant_helper_mock_fails(_: &LogicalPlan) -> Result<()> { + internal_err!("node fails check, such as improper inputs") +} + impl UserDefinedLogicalNodeCore for TopKPlanNode { fn name(&self) -> &str { "TopK" @@ -420,6 +499,26 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { self.input.schema() } + fn invariants(&self) -> Vec { + if let Some(InvariantMock { + should_fail_invariant, + kind, + }) = self.invariant_mock.clone() + { + if should_fail_invariant { + return vec![Invariant { + kind, + fun: Arc::new(invariant_helper_mock_fails), + }]; + } + return vec![Invariant { + kind, + fun: Arc::new(invariant_helper_mock_ok), + }]; + } + vec![] // same as default impl + } + fn expressions(&self) -> Vec { vec![self.expr.expr.clone()] } @@ -440,6 +539,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { k: 
self.k, input: inputs.swap_remove(0), expr: self.expr.with_expr(exprs.swap_remove(0)), + invariant_mock: self.invariant_mock.clone(), }) } diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index 19d4cb3db9ce..6a3d0dc97a63 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -22,6 +22,9 @@ use std::cmp::Ordering; use std::hash::{Hash, Hasher}; use std::{any::Any, collections::HashSet, fmt, sync::Arc}; +use super::invariants::Invariant; +use super::InvariantLevel; + /// This defines the interface for [`LogicalPlan`] nodes that can be /// used to extend DataFusion with custom relational operators. /// @@ -54,6 +57,22 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// Return the output schema of this logical plan node. fn schema(&self) -> &DFSchemaRef; + /// Return the list of invariants. + /// + /// Implementing this function enables the user to define the + /// invariants for a given logical plan extension. + fn invariants(&self) -> Vec { + vec![] + } + + /// Perform check of invariants for the extension node. + fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> { + self.invariants() + .into_iter() + .filter(|inv| check == inv.kind) + .try_for_each(|inv| inv.check(plan)) + } + /// Returns all expressions in the current logical plan node. This should /// not include expressions of any inputs (aka non-recursively). /// @@ -244,6 +263,14 @@ pub trait UserDefinedLogicalNodeCore: /// Return the output schema of this logical plan node. fn schema(&self) -> &DFSchemaRef; + /// Return the list of invariants. + /// + /// Implementing this function enables the user to define the + /// invariants for a given logical plan extension. + fn invariants(&self) -> Vec { + vec![] + } + /// Returns all expressions in the current logical plan node. 
This /// should not include expressions of any inputs (aka /// non-recursively). These expressions are used for optimizer @@ -336,6 +363,10 @@ impl UserDefinedLogicalNode for T { self.schema() } + fn invariants(&self) -> Vec { + self.invariants() + } + fn expressions(&self) -> Vec { self.expressions() } diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index bde4acaae562..4959dafdef99 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use datafusion_common::{ internal_err, plan_err, tree_node::{TreeNode, TreeNodeRecursion}, @@ -28,6 +30,24 @@ use crate::{ Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window, }; +use super::Extension; + +pub type InvariantFn = Arc Result<()> + Send + Sync>; + +#[derive(Clone)] +pub struct Invariant { + pub kind: InvariantLevel, + pub fun: InvariantFn, +} + +impl Invariant { + /// Return an error if invariant does not hold true. + pub fn check(&self, plan: &LogicalPlan) -> Result<()> { + (self.fun)(plan) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum InvariantLevel { /// Invariants that are always true in DataFusion `LogicalPlan`s /// such as the number of expected children and no duplicated output fields @@ -41,6 +61,7 @@ pub enum InvariantLevel { Executable, } +/// Apply the [`InvariantLevel::Always`] check at the root plan node only. pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> { // Refer to assert_unique_field_names(plan)?; @@ -48,12 +69,46 @@ pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> { Ok(()) } +/// Visit the plan nodes, and confirm the [`InvariantLevel::Executable`] +/// as well as the less stringent [`InvariantLevel::Always`] checks. 
pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> { + // Always invariants assert_always_invariants(plan)?; + assert_valid_extension_nodes(plan, InvariantLevel::Always)?; + + // Executable invariants + assert_valid_extension_nodes(plan, InvariantLevel::Executable)?; assert_valid_semantic_plan(plan)?; Ok(()) } +/// Asserts that the query plan, and subplan, extension nodes have valid invariants. +/// +/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode) +/// for more details of user-provided extension node invariants. +fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> { + plan.apply_with_subqueries(|plan: &LogicalPlan| { + if let LogicalPlan::Extension(Extension { node }) = plan { + node.check_invariants(check, plan)?; + } + plan.apply_expressions(|expr| { + // recursively look for subqueries + expr.apply(|expr| { + match expr { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + assert_valid_extension_nodes(&subquery.subquery, check)?; + } + _ => {} + }; + Ok(TreeNodeRecursion::Continue) + }) + }) + }) + .map(|_| ()) +} + /// Returns an error if plan, and subplans, do not have unique fields. /// /// This invariant is subject to change. @@ -87,7 +142,7 @@ pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Resul /// Asserts that the subqueries are structured properly with valid node placement. /// -/// Refer to [`check_subquery_expr`] for more details. +/// Refer to [`check_subquery_expr`] for more details of the internal invariants. 
fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> { plan.apply_with_subqueries(|plan: &LogicalPlan| { plan.apply_expressions(|expr| { diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 404941378663..48cd19ed13ed 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -21,7 +21,9 @@ pub mod display; pub mod dml; mod extension; pub(crate) mod invariants; -pub use invariants::{assert_expected_schema, check_subquery_expr, InvariantLevel}; +pub use invariants::{ + assert_expected_schema, check_subquery_expr, Invariant, InvariantFn, InvariantLevel, +}; mod plan; mod statement; pub mod tree_node; From 383644478f4a56e83e245b702bcb9c214d242ed1 Mon Sep 17 00:00:00 2001 From: wiedld Date: Wed, 22 Jan 2025 03:40:19 -0800 Subject: [PATCH 2/4] test(13525): demonstrate extension node invariants catching improper mutation during an optimizer pass --- .../tests/user_defined/user_defined_plan.rs | 99 +++++++++++++++++++ .../expr/src/logical_plan/invariants.rs | 6 +- datafusion/expr/src/logical_plan/plan.rs | 5 +- 3 files changed, 105 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 4ba9f21b641c..dcfb30198725 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -346,6 +346,62 @@ async fn topk_invariants() -> Result<()> { Ok(()) } +#[tokio::test] +async fn topk_invariants_after_invalid_mutation() -> Result<()> { + // CONTROL + // Build a valid topK plan. + let config = SessionConfig::new().with_target_partitions(48); + let runtime = Arc::new(RuntimeEnv::default()); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .with_query_planner(Arc::new(TopKQueryPlanner {})) + // 1. 
adds a valid TopKPlanNode + .with_optimizer_rule(Arc::new(TopKOptimizerRule { + invariant_mock: Some(InvariantMock { + should_fail_invariant: false, + kind: InvariantLevel::Always, + }), + })) + .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .build(); + let ctx = setup_table(SessionContext::new_with_state(state)).await?; + run_and_compare_query(ctx, "Topk context").await?; + + // Test + // Build a valid topK plan. + // Then have an invalid mutation in an optimizer run. + let config = SessionConfig::new().with_target_partitions(48); + let runtime = Arc::new(RuntimeEnv::default()); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .with_query_planner(Arc::new(TopKQueryPlanner {})) + // 1. adds a valid TopKPlanNode + .with_optimizer_rule(Arc::new(TopKOptimizerRule { + invariant_mock: Some(InvariantMock { + should_fail_invariant: false, + kind: InvariantLevel::Always, + }), + })) + // 2. break the TopKPlanNode + .with_optimizer_rule(Arc::new(OptimizerMakeExtensionNodeInvalid {})) + .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .build(); + let ctx = setup_table(SessionContext::new_with_state(state)).await?; + matches!( + &*run_and_compare_query(ctx, "Topk context") + .await + .unwrap_err() + .message(), + "node fails check, such as improper inputs" + ); + + Ok(()) +} + fn make_topk_context() -> SessionContext { make_topk_context_with_invariants(None) } @@ -366,6 +422,49 @@ fn make_topk_context_with_invariants( SessionContext::new_with_state(state) } +#[derive(Debug)] +struct OptimizerMakeExtensionNodeInvalid; + +impl OptimizerRule for OptimizerMakeExtensionNodeInvalid { + fn name(&self) -> &str { + "OptimizerMakeExtensionNodeInvalid" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn supports_rewrite(&self) -> bool { + true + } + + // Example rewrite pass which impacts validity of the extension node. 
+ fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result<Transformed<LogicalPlan>, DataFusionError> { + if let LogicalPlan::Extension(Extension { node }) = &plan { + if let Some(prev) = node.as_any().downcast_ref::<TopKPlanNode>() { + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: prev.k, + input: prev.input.clone(), + expr: prev.expr.clone(), + // In a real use case, this rewriter could have changed the number of inputs, etc + invariant_mock: Some(InvariantMock { + should_fail_invariant: true, + kind: InvariantLevel::Always, + }), + }), + }))); + } + }; + + Ok(Transformed::no(plan)) + } +} + // ------ The implementation of the TopK code follows ----- #[derive(Debug)] diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 4959dafdef99..fb50fbe42e81 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -61,8 +61,8 @@ pub enum InvariantLevel { Executable, } -/// Apply the [`InvariantLevel::Always`] check at the root plan node only. -pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> { +/// Apply the [`InvariantLevel::Always`] check at the current plan node only. +pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> { // Refer to assert_unique_field_names(plan)?; @@ -73,7 +73,7 @@ pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> { /// as well as the less stringent [`InvariantLevel::Always`] checks. 
pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> { // Always invariants - assert_always_invariants(plan)?; + assert_always_invariants_at_current_node(plan)?; assert_valid_extension_nodes(plan, InvariantLevel::Always)?; // Executable invariants diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 24fb0609b0fe..ebad1dcf9de4 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -25,7 +25,8 @@ use std::sync::{Arc, LazyLock}; use super::dml::CopyTo; use super::invariants::{ - assert_always_invariants, assert_executable_invariants, InvariantLevel, + assert_always_invariants_at_current_node, assert_executable_invariants, + InvariantLevel, }; use super::DdlStatement; use crate::builder::{change_redundant_column, unnest_with_options}; @@ -1137,7 +1138,7 @@ impl LogicalPlan { /// checks that the plan conforms to the listed invariant level, returning an Error if not pub fn check_invariants(&self, check: InvariantLevel) -> Result<()> { match check { - InvariantLevel::Always => assert_always_invariants(self), + InvariantLevel::Always => assert_always_invariants_at_current_node(self), InvariantLevel::Executable => assert_executable_invariants(self), } } From 0cda03068f2173353a4b32260648b5d3a62d14da Mon Sep 17 00:00:00 2001 From: wiedld Date: Thu, 30 Jan 2025 19:20:09 -0800 Subject: [PATCH 3/4] chore: update docs --- datafusion/expr/src/logical_plan/invariants.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index fb50fbe42e81..eb6e5f69efd5 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -62,6 +62,8 @@ pub enum InvariantLevel { } /// Apply the [`InvariantLevel::Always`] check at the current plan node only. +/// +/// This does not recurse to any child nodes. 
pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> { // Refer to assert_unique_field_names(plan)?; From 159d62d4f176f34f29ae4b137e0c82ff44fd5dd2 Mon Sep 17 00:00:00 2001 From: wiedld Date: Fri, 31 Jan 2025 10:58:09 -0800 Subject: [PATCH 4/4] refactor: remove the extra Invariant interface around an FnMut, since it doesn't make sense for the extension node's checks --- .../tests/user_defined/user_defined_plan.rs | 25 +++----------- datafusion/expr/src/logical_plan/extension.rs | 33 +++++++------------ .../expr/src/logical_plan/invariants.rs | 17 ---------- datafusion/expr/src/logical_plan/mod.rs | 4 +-- .../tests/cases/roundtrip_logical_plan.rs | 12 +++++-- 5 files changed, 27 insertions(+), 64 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index dcfb30198725..fae4b2cd82ab 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -94,7 +94,7 @@ use datafusion::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; -use datafusion_expr::{FetchType, Invariant, InvariantLevel, Projection, SortExpr}; +use datafusion_expr::{FetchType, InvariantLevel, Projection, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -576,14 +576,6 @@ struct InvariantMock { kind: InvariantLevel, } -fn invariant_helper_mock_ok(_: &LogicalPlan) -> Result<()> { - Ok(()) -} - -fn invariant_helper_mock_fails(_: &LogicalPlan) -> Result<()> { - internal_err!("node fails check, such as improper inputs") -} - impl UserDefinedLogicalNodeCore for TopKPlanNode { fn name(&self) -> &str { "TopK" @@ -598,24 +590,17 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { 
self.input.schema() } - fn invariants(&self) -> Vec { + fn check_invariants(&self, check: InvariantLevel, _plan: &LogicalPlan) -> Result<()> { if let Some(InvariantMock { should_fail_invariant, kind, }) = self.invariant_mock.clone() { - if should_fail_invariant { - return vec![Invariant { - kind, - fun: Arc::new(invariant_helper_mock_fails), - }]; + if should_fail_invariant && check == kind { + return internal_err!("node fails check, such as improper inputs"); } - return vec![Invariant { - kind, - fun: Arc::new(invariant_helper_mock_ok), - }]; } - vec![] // same as default impl + Ok(()) } fn expressions(&self) -> Vec { diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index 6a3d0dc97a63..be7153cc4eaa 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -22,7 +22,6 @@ use std::cmp::Ordering; use std::hash::{Hash, Hasher}; use std::{any::Any, collections::HashSet, fmt, sync::Arc}; -use super::invariants::Invariant; use super::InvariantLevel; /// This defines the interface for [`LogicalPlan`] nodes that can be @@ -57,21 +56,8 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// Return the output schema of this logical plan node. fn schema(&self) -> &DFSchemaRef; - /// Return the list of invariants. - /// - /// Implementing this function enables the user to define the - /// invariants for a given logical plan extension. - fn invariants(&self) -> Vec { - vec![] - } - /// Perform check of invariants for the extension node. - fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> { - self.invariants() - .into_iter() - .filter(|inv| check == inv.kind) - .try_for_each(|inv| inv.check(plan)) - } + fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()>; /// Returns all expressions in the current logical plan node. 
This should /// not include expressions of any inputs (aka non-recursively). @@ -263,12 +249,15 @@ pub trait UserDefinedLogicalNodeCore: /// Return the output schema of this logical plan node. fn schema(&self) -> &DFSchemaRef; - /// Return the list of invariants. + /// Perform check of invariants for the extension node. /// - /// Implementing this function enables the user to define the - /// invariants for a given logical plan extension. - fn invariants(&self) -> Vec { - vec![] + /// This is the default implementation for extension nodes. + fn check_invariants( + &self, + _check: InvariantLevel, + _plan: &LogicalPlan, + ) -> Result<()> { + Ok(()) } /// Returns all expressions in the current logical plan node. This @@ -363,8 +352,8 @@ impl UserDefinedLogicalNode for T { self.schema() } - fn invariants(&self) -> Vec { - self.invariants() + fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> { + self.check_invariants(check, plan) } fn expressions(&self) -> Vec { diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index eb6e5f69efd5..c8f1fcd2d90b 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use datafusion_common::{ internal_err, plan_err, tree_node::{TreeNode, TreeNodeRecursion}, @@ -32,21 +30,6 @@ use crate::{ use super::Extension; -pub type InvariantFn = Arc Result<()> + Send + Sync>; - -#[derive(Clone)] -pub struct Invariant { - pub kind: InvariantLevel, - pub fun: InvariantFn, -} - -impl Invariant { - /// Return an error if invariant does not hold true. 
- pub fn check(&self, plan: &LogicalPlan) -> Result<()> { - (self.fun)(plan) - } -} - #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum InvariantLevel { /// Invariants that are always true in DataFusion `LogicalPlan`s diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 48cd19ed13ed..404941378663 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -21,9 +21,7 @@ pub mod display; pub mod dml; mod extension; pub(crate) mod invariants; -pub use invariants::{ - assert_expected_schema, check_subquery_expr, Invariant, InvariantFn, InvariantLevel, -}; +pub use invariants::{assert_expected_schema, check_subquery_expr, InvariantLevel}; mod plan; mod statement; pub mod tree_node; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 7045729493b1..5fb357dfcd23 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -32,8 +32,8 @@ use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::{ - Extension, LogicalPlan, PartitionEvaluator, Repartition, UserDefinedLogicalNode, - Values, Volatility, + Extension, InvariantLevel, LogicalPlan, PartitionEvaluator, Repartition, + UserDefinedLogicalNode, Values, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -111,6 +111,14 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { &self.empty_schema } + fn check_invariants( + &self, + _check: InvariantLevel, + _plan: &LogicalPlan, + ) -> Result<()> { + Ok(()) + } + fn expressions(&self) -> Vec { vec![] }