diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index d29ccdc6a7e9..0c7769639ac7 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -21,13 +21,13 @@ use datafusion_expr::planner::{
 };
 use sqlparser::ast::{
     AccessExpr, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
-    DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry,
+    DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, Ident, MapEntry,
     StructField, Subscript, TrimWhereField, Value, ValueWithSpan,
 };
 
 use datafusion_common::{
-    internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Result,
-    ScalarValue,
+    internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Diagnostic,
+    Result, ScalarValue, Span,
 };
 
 use datafusion_expr::expr::ScalarFunction;
@@ -86,6 +86,42 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 StackEntry::SQLExpr(sql_expr) => {
                     match *sql_expr {
                         SQLExpr::BinaryOp { left, op, right } => {
+                            // Detect if there is "= Null" in SQL
+                            if op == BinaryOperator::Eq {
+                                if let SQLExpr::Value(ValueWithSpan {
+                                    value: Value::Null,
+                                    span: null_span,
+                                }) = *right
+                                {
+                                    let left_span = match &*left {
+                                        SQLExpr::Identifier(Ident { span, .. }) => *span,
+                                        // In this case, we expect left to be
+                                        // Identifier. Just to make the code
+                                        // more robust, we'll make left_span
+                                        // equal to null_span otherwise.
+                                        _ => null_span,
+                                    };
+                                    let combined_span = Span {
+                                        start: Into::into(left_span.start),
+                                        end: Into::into(null_span.end),
+                                    };
+
+                                    let diagnostic = Diagnostic::new_warning(
+                                        "Ambiguous NULL comparison".to_string(),
+                                        Some(combined_span),
+                                    )
+                                    .with_help(
+                                        "Use IS NULL instead of = NULL",
+                                        Some(Span {
+                                            start: Into::into(null_span.start),
+                                            end: Into::into(null_span.end),
+                                        }),
+                                    );
+
+                                    self.warnings.borrow_mut().push(diagnostic);
+                                }
+                            }
+
                             // Note the order that we push the entries to the stack
                             // is important. We want to visit the left node first.
                             stack.push(StackEntry::Operator(op));
@@ -1174,7 +1210,7 @@ mod tests {
     use sqlparser::parser::Parser;
 
     use datafusion_common::config::ConfigOptions;
-    use datafusion_common::TableReference;
+    use datafusion_common::{Location, TableReference};
     use datafusion_expr::logical_plan::builder::LogicalTableSource;
     use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
 
@@ -1316,4 +1352,96 @@ mod tests {
 
         assert!(matches!(expr, Expr::Alias(_)));
     }
+
+    // Helper to parse SQL expressions
+    fn parse_expr(sql: &str) -> SQLExpr {
+        let dialect = GenericDialect {};
+        Parser::new(&dialect)
+            .try_with_sql(sql)
+            .unwrap()
+            .parse_expr()
+            .unwrap()
+    }
+
+    #[test]
+    fn test_single_null_comparison() {
+        let context = TestContextProvider::new();
+        let planner = SqlToRel::new(&context);
+
+        // Test single = NULL case
+        let expr = parse_expr("password = NULL");
+        let _ = planner
+            .sql_expr_to_logical_expr(
+                expr,
+                &DFSchema::empty(),
+                &mut PlannerContext::new(),
+            )
+            .unwrap();
+
+        let warnings = planner.warnings.take();
+        assert_eq!(warnings.len(), 1, "Should detect 1 warning");
+        let warning = &warnings[0];
+        assert_eq!(warning.message, "Ambiguous NULL comparison");
+
+        assert_eq!(
+            warning.span,
+            Some(Span {
+                start: Location { line: 1, column: 1 },
+                end: Location {
+                    line: 1,
+                    column: 16
+                }
+            })
+        );
+
+        assert_eq!(warning.helps.len(), 1);
+        let help = &warning.helps[0];
+        assert_eq!(help.message, "Use IS NULL instead of = NULL");
+    }
+
+    #[test]
+    fn test_multiple_null_comparisons() {
+        let context = TestContextProvider::new();
+        let planner = SqlToRel::new(&context);
+
+        // Test multiple = NULL cases
+        let expr = parse_expr("(name = NULL) OR (age = NULL)");
+        let _ = planner
+            .sql_expr_to_logical_expr(
+                expr,
+                &DFSchema::empty(),
+                &mut PlannerContext::new(),
+            )
+            .unwrap();
+
+        let warnings = planner.warnings.take();
+        assert_eq!(warnings.len(), 2, "Should detect 2 warnings");
+
+        let first = &warnings[0];
+        assert_eq!(
+            first.span,
+            Some(Span {
+                start: Location { line: 1, column: 2 },
+                end: Location {
+                    line: 1,
+                    column: 13
+                }
+            })
+        );
+
+        let second = &warnings[1];
+        assert_eq!(
+            second.span,
+            Some(Span {
+                start: Location {
+                    line: 1,
+                    column: 19
+                },
+                end: Location {
+                    line: 1,
+                    column: 29
+                }
+            })
+        );
+    }
 }
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 73d136d7d1cc..31f60306a3b1 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 //! [`SqlToRel`]: SQL Query Planner (produces [`LogicalPlan`] from SQL AST)
+use std::cell::RefCell;
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::vec;
@@ -337,6 +338,7 @@ pub struct SqlToRel<'a, S: ContextProvider> {
     pub(crate) context_provider: &'a S,
     pub(crate) options: ParserOptions,
     pub(crate) ident_normalizer: IdentNormalizer,
+    pub(crate) warnings: RefCell<Vec<Diagnostic>>,
 }
 
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
@@ -359,6 +361,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             context_provider,
             options,
             ident_normalizer: IdentNormalizer::new(ident_normalize),
+            warnings: RefCell::new(Vec::new()),
         }
     }
 